From 76d464e0c2086bbbeef8d94d9fc070e3bfe92190 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 02:53:46 -0500 Subject: [PATCH 01/26] baml_language: add throws field to function type representations Add `throws: Box` to `Ty::Function` across all compiler layers (parser, CST, AST, TIR, normalization, MIR, runtime type) so the type system can represent `() -> T throws E`. Defaults to `Never` when omitted. Includes covariant subtyping for throws and test coverage for the new syntax. --- .../crates/baml_compiler2_ast/src/ast.rs | 3 +- .../crates/baml_compiler2_ast/src/lib.rs | 8 ++- .../baml_compiler2_ast/src/lower_type_expr.rs | 7 +++ .../crates/baml_compiler2_mir/src/lower.rs | 13 ++++- .../crates/baml_compiler2_tir/src/builder.rs | 18 ++++++ .../crates/baml_compiler2_tir/src/generics.rs | 58 ++++++++++++++++--- .../baml_compiler2_tir/src/lower_type_expr.rs | 24 +++++++- .../baml_compiler2_tir/src/normalize.rs | 36 +++++++++++- .../crates/baml_compiler2_tir/src/ty.rs | 16 ++++- .../crates/baml_compiler_parser/src/parser.rs | 7 +++ .../crates/baml_compiler_syntax/src/ast.rs | 8 +++ .../projects/function_type_throws/basic.baml | 8 +++ .../baml_tests____baml_std____04_tir.snap | 4 +- ...baml_tests____testing_std____04_5_mir.snap | 4 +- ...function_type_throws__01_lexer__basic.snap | 53 +++++++++++++++++ ...unction_type_throws__02_parser__basic.snap | 53 +++++++++++++++++ ...l_tests__function_type_throws__03_hir.snap | 8 +++ ...tests__function_type_throws__04_5_mir.snap | 21 +++++++ ...l_tests__function_type_throws__04_tir.snap | 11 ++++ ..._function_type_throws__05_diagnostics.snap | 5 ++ ...sts__function_type_throws__06_codegen.snap | 8 +++ ...tion_type_throws__10_formatter__basic.snap | 5 ++ baml_language/crates/baml_type/src/lib.rs | 15 ++++- 23 files changed, 369 insertions(+), 24 deletions(-) create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/basic.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__basic.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__basic.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__basic.snap diff --git a/baml_language/crates/baml_compiler2_ast/src/ast.rs b/baml_language/crates/baml_compiler2_ast/src/ast.rs index bd5d117bd8..298ba428b2 100644 --- a/baml_language/crates/baml_compiler2_ast/src/ast.rs +++ b/baml_language/crates/baml_compiler2_ast/src/ast.rs @@ -97,10 +97,11 @@ pub enum TypeExpr { value: baml_base::Literal, attrs: Vec, }, - /// Function type: (params) -> return + /// Function type: (params) -> return [throws E] Function { params: Vec, ret: Box, + throws: Option>, attrs: Vec, }, /// The `unknown` keyword type diff --git a/baml_language/crates/baml_compiler2_ast/src/lib.rs b/baml_language/crates/baml_compiler2_ast/src/lib.rs index 00d553d95c..62872e5753 100644 --- a/baml_language/crates/baml_compiler2_ast/src/lib.rs +++ b/baml_language/crates/baml_compiler2_ast/src/lib.rs @@ -178,7 +178,12 @@ mod tests { value: value.clone(), attrs: strip_attrs(attrs), }, - TypeExpr::Function { params, ret, attrs } => TypeExpr::Function { + TypeExpr::Function { + params, + ret, + throws, + attrs, + } => TypeExpr::Function { params: params .iter() .map(|p| crate::ast::FunctionTypeParam { @@ -187,6 +192,7 @@ mod tests { }) .collect(), ret: Box::new(strip_spans(ret)), + throws: throws.as_ref().map(|t| Box::new(strip_spans(t))), attrs: strip_attrs(attrs), }, TypeExpr::Media { kind, attrs } => TypeExpr::Media { diff --git a/baml_language/crates/baml_compiler2_ast/src/lower_type_expr.rs b/baml_language/crates/baml_compiler2_ast/src/lower_type_expr.rs index a8e0af63fd..61bfbaf971 100644 --- a/baml_language/crates/baml_compiler2_ast/src/lower_type_expr.rs +++ b/baml_language/crates/baml_compiler2_ast/src/lower_type_expr.rs @@ -164,9 +164,16 @@ fn lower_base_terminal(type_expr: &CstTypeExpr) -> TypeExpr { .function_return_type() .map(|t| lower_type_expr_inner(&t, false)) .unwrap_or(TypeExpr::Unknown { attrs: vec![] }); + let throws = type_expr.function_throws_type().map(|throws_clause| { + let throws_type = throws_clause + .type_expr() + .expect("THROWS_CLAUSE should have a type"); + Box::new(lower_type_expr_inner(&throws_type, false)) + }); return TypeExpr::Function { params, ret: Box::new(ret), + throws, attrs: vec![], }; } diff --git a/baml_language/crates/baml_compiler2_mir/src/lower.rs b/baml_language/crates/baml_compiler2_mir/src/lower.rs index eb1cd98552..d230fe8273 100644 --- a/baml_language/crates/baml_compiler2_mir/src/lower.rs +++ b/baml_language/crates/baml_compiler2_mir/src/lower.rs @@ -164,13 +164,22 @@ pub fn convert_tir2_ty(ty: &Tir2Ty, resolved: &ResolvedAliases) -> Ty { attr: attr.clone(), }, - // Functions — drop param names - Tir2Ty::Function { params, ret, attr } => Ty::Function { + // Functions — drop param names; map Never throws to None + Tir2Ty::Function { + params, + ret, + throws, + attr, + } => Ty::Function { params: params .iter() .map(|(_, t)| convert_tir2_ty(t, resolved)) .collect(), ret: Box::new(convert_tir2_ty(ret, resolved)), + throws: match throws.as_ref() { + Tir2Ty::Never { .. } => None, + t => Some(Box::new(convert_tir2_ty(t, resolved))), + }, attr: attr.clone(), }, diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 906caa4fe4..aed58bd3b9 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -725,6 +725,9 @@ impl<'db> TypeInferenceBuilder<'db> { Ty::Function { params: param_tys, ret: Box::new(ret_ty), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), } } @@ -1174,6 +1177,9 @@ impl<'db> TypeInferenceBuilder<'db> { let result = Ty::Function { params: param_tys, ret: Box::new(ret_ty), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), }; self.record_expr_type(expr_id, result.clone()); @@ -3042,6 +3048,9 @@ impl<'db> TypeInferenceBuilder<'db> { attr: TyAttr::default(), }), ), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), }; return Some(ty); @@ -3130,6 +3139,9 @@ impl<'db> TypeInferenceBuilder<'db> { attr: TyAttr::default(), }), ), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), } } @@ -3635,6 +3647,9 @@ impl<'db> TypeInferenceBuilder<'db> { attr: TyAttr::default(), }), ), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), }; // Note: diags from method signatures are reported at definition site. @@ -4043,6 +4058,9 @@ impl<'db> TypeInferenceBuilder<'db> { ty: Ty::Function { params, ret: Box::new(ret), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), }, class_loc, diff --git a/baml_language/crates/baml_compiler2_tir/src/generics.rs b/baml_language/crates/baml_compiler2_tir/src/generics.rs index 563ff93656..595775118c 100644 --- a/baml_language/crates/baml_compiler2_tir/src/generics.rs +++ b/baml_language/crates/baml_compiler2_tir/src/generics.rs @@ -73,12 +73,18 @@ pub fn substitute_ty(ty: &Ty, bindings: &FxHashMap) -> Ty { members.iter().map(|m| substitute_ty(m, bindings)).collect(), attr.clone(), ), - Ty::Function { params, ret, attr } => Ty::Function { + Ty::Function { + params, + ret, + throws, + attr, + } => Ty::Function { params: params .iter() .map(|(n, t)| (n.clone(), substitute_ty(t, bindings))) .collect(), ret: Box::new(substitute_ty(ret, bindings)), + throws: Box::new(substitute_ty(throws, bindings)), attr: attr.clone(), }, // All other types are leaves (primitives, class refs, enums, etc.) — pass through. @@ -200,7 +206,12 @@ pub fn lower_type_expr_with_generics( .collect(), TyAttr::default(), ), - TypeExpr::Function { params, ret, .. } => Ty::Function { + TypeExpr::Function { + params, + ret, + throws, + .. + } => Ty::Function { params: params .iter() .map(|p| { @@ -225,6 +236,23 @@ pub fn lower_type_expr_with_generics( bindings, diagnostics, )), + throws: Box::new( + throws + .as_ref() + .map(|t| { + lower_type_expr_with_generics( + db, + t, + package_items, + ns_context, + bindings, + diagnostics, + ) + }) + .unwrap_or_else(|| Ty::Never { + attr: TyAttr::default(), + }), + ), attr: TyAttr::default(), }, // For all other type expressions (primitives, multi-segment paths, etc.), @@ -265,8 +293,15 @@ pub fn contains_typevar(ty: &Ty) -> bool { } Ty::Map(k, v, _) | Ty::EvolvingMap(k, v, _) => contains_typevar(k) || contains_typevar(v), Ty::Union(tys, _) => tys.iter().any(contains_typevar), - Ty::Function { params, ret, .. } => { - params.iter().any(|(_, t)| contains_typevar(t)) || contains_typevar(ret) + Ty::Function { + params, + ret, + throws, + .. + } => { + params.iter().any(|(_, t)| contains_typevar(t)) + || contains_typevar(ret) + || contains_typevar(throws) } _ => false, } @@ -301,18 +336,21 @@ pub fn infer_bindings(formal: &Ty, actual: &Ty, bindings: &mut FxHashMap { - for ((_, ft), (_, at)) in fp.iter().zip(ap.iter()) { - infer_bindings(ft, at, bindings); + for ((_, fpt), (_, apt)) in fp.iter().zip(ap.iter()) { + infer_bindings(fpt, apt, bindings); } infer_bindings(fr, ar, bindings); + infer_bindings(fthrows, athrows, bindings); } _ => {} // Concrete types: nothing to infer } @@ -385,12 +423,18 @@ pub fn erase_unresolved_typevars( Box::new(erase_unresolved_typevars(inner, diagnostics)), attr.clone(), ), - Ty::Function { params, ret, attr } => Ty::Function { + Ty::Function { + params, + ret, + throws, + attr, + } => Ty::Function { params: params .iter() .map(|(n, t)| (n.clone(), erase_unresolved_typevars(t, diagnostics))) .collect(), ret: Box::new(erase_unresolved_typevars(ret, diagnostics)), + throws: Box::new(erase_unresolved_typevars(throws, diagnostics)), attr: attr.clone(), }, Ty::Union(tys, attr) => Ty::Union( diff --git a/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs b/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs index 74895daa5a..cd785a78d5 100644 --- a/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs +++ b/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs @@ -220,7 +220,12 @@ pub fn lower_type_expr_in_ns( .collect(), TyAttr::default(), ), - TypeExpr::Function { params, ret, .. } => Ty::Function { + TypeExpr::Function { + params, + ret, + throws, + .. + } => Ty::Function { params: params .iter() .map(|p| { @@ -245,6 +250,23 @@ pub fn lower_type_expr_in_ns( generic_params, diagnostics, )), + throws: Box::new( + throws + .as_ref() + .map(|t| { + lower_type_expr_in_ns( + db, + t, + package_items, + ns_context, + generic_params, + diagnostics, + ) + }) + .unwrap_or_else(|| Ty::Never { + attr: TyAttr::default(), + }), + ), attr: TyAttr::default(), }, TypeExpr::Literal { value: lit, .. } => { diff --git a/baml_language/crates/baml_compiler2_tir/src/normalize.rs b/baml_language/crates/baml_compiler2_tir/src/normalize.rs index dd63225415..d3e8771aec 100644 --- a/baml_language/crates/baml_compiler2_tir/src/normalize.rs +++ b/baml_language/crates/baml_compiler2_tir/src/normalize.rs @@ -72,6 +72,7 @@ enum StructuralTy { Function { params: Vec, ret: Box, + throws: Box, }, // Recursion Mu { @@ -229,20 +230,26 @@ impl StructuralTy { // EnumVariant(E, V) <: Enum(E) (StructuralTy::EnumVariant(e, _), StructuralTy::Enum(sup_e)) => e == sup_e, - // Function subtyping: contravariant params, covariant return + // Function subtyping: contravariant params, covariant return and throws ( StructuralTy::Function { params: params1, ret: ret1, + throws: throws1, }, StructuralTy::Function { params: params2, ret: ret2, + throws: throws2, }, ) => { if !ret1.is_subtype_of(ret2, assumptions) { return false; } + // Throws is covariant: fewer throws is more specific + if !throws1.is_subtype_of(throws2, assumptions) { + return false; + } if params2.len() > params1.len() { return false; } @@ -286,12 +293,17 @@ fn substitute( .map(|t| substitute(t, var, replacement)) .collect(), ), - StructuralTy::Function { params, ret } => StructuralTy::Function { + StructuralTy::Function { + params, + ret, + throws, + } => StructuralTy::Function { params: params .iter() .map(|t| substitute(t, var, replacement)) .collect(), ret: Box::new(substitute(ret, var, replacement)), + throws: Box::new(substitute(throws, var, replacement)), }, StructuralTy::Mu { var: v, body } if v != var => StructuralTy::Mu { var: v.clone(), @@ -381,12 +393,18 @@ fn normalize_impl( .map(|t| normalize_impl(t, aliases, recursive, expanding)) .collect(), ), - Ty::Function { params, ret, .. } => StructuralTy::Function { + Ty::Function { + params, + ret, + throws, + .. + } => StructuralTy::Function { params: params .iter() .map(|(_, t)| normalize_impl(t, aliases, recursive, expanding)) .collect(), ret: Box::new(normalize_impl(ret, aliases, recursive, expanding)), + throws: Box::new(normalize_impl(throws, aliases, recursive, expanding)), }, Ty::TypeVar(name, _) => StructuralTy::TypeVar(name.clone()), // `$rust_type` — opaque Rust-managed state. Treated as Unknown @@ -1147,11 +1165,17 @@ mod tests { let f1 = Ty::Function { params: vec![(None, Ty::Primitive(PrimitiveType::Int, TyAttr::default()))], ret: Box::new(Ty::Primitive(PrimitiveType::Int, TyAttr::default())), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), }; let f2 = Ty::Function { params: vec![(None, Ty::Primitive(PrimitiveType::Int, TyAttr::default()))], ret: Box::new(Ty::Primitive(PrimitiveType::Float, TyAttr::default())), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), }; assert!(is_subtype_of(&f1, &f2, &aliases)); @@ -1164,11 +1188,17 @@ mod tests { let f1 = Ty::Function { params: vec![(None, Ty::Primitive(PrimitiveType::Float, TyAttr::default()))], ret: Box::new(Ty::Primitive(PrimitiveType::String, TyAttr::default())), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), }; let f2 = Ty::Function { params: vec![(None, Ty::Primitive(PrimitiveType::Int, TyAttr::default()))], ret: Box::new(Ty::Primitive(PrimitiveType::String, TyAttr::default())), + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), attr: TyAttr::default(), }; assert!(is_subtype_of(&f1, &f2, &aliases)); diff --git a/baml_language/crates/baml_compiler2_tir/src/ty.rs b/baml_language/crates/baml_compiler2_tir/src/ty.rs index a96d8f44ae..90c5777959 100644 --- a/baml_language/crates/baml_compiler2_tir/src/ty.rs +++ b/baml_language/crates/baml_compiler2_tir/src/ty.rs @@ -159,10 +159,11 @@ pub enum Ty { /// Evolving map — created from empty map literal at mutable binding sites. /// Same semantics as `EvolvingList` but for maps (see doc on `EvolvingList`). EvolvingMap(Box, Box, TyAttr), - /// Function type: (params) -> return. + /// Function type: (params) -> return [throws E]. Function { params: Vec<(Option, Ty)>, ret: Box, + throws: Box, attr: TyAttr, }, /// A type variable (generic parameter) — e.g. `T` in `Array`. @@ -491,7 +492,12 @@ impl fmt::Display for Ty { write!(f, "?") } Ty::Literal(lit, _freshness, _) => write!(f, "{lit}"), - Ty::Function { params, ret, .. } => { + Ty::Function { + params, + ret, + throws, + .. + } => { let ps: Vec = params .iter() .map(|(name, ty)| { @@ -500,7 +506,11 @@ impl fmt::Display for Ty { .unwrap_or_else(|| ty.to_string()) }) .collect(); - write!(f, "({}) -> {ret}", ps.join(", ")) + write!(f, "({}) -> {ret}", ps.join(", "))?; + if !matches!(throws.as_ref(), Ty::Never { .. }) { + write!(f, " throws {throws}")?; + } + Ok(()) } Ty::TypeVar(name, _) => write!(f, "{name}"), Ty::Never { .. } => write!(f, "never"), diff --git a/baml_language/crates/baml_compiler_parser/src/parser.rs b/baml_language/crates/baml_compiler_parser/src/parser.rs index 6b1ebe0b80..66be3fd7a4 100644 --- a/baml_language/crates/baml_compiler_parser/src/parser.rs +++ b/baml_language/crates/baml_compiler_parser/src/parser.rs @@ -2011,6 +2011,13 @@ impl<'a> Parser<'a> { // Note: The tokens are already emitted, we just need to parse the return type self.bump(); // -> self.parse_type(); // return type + // Optional throws clause on function type + if self.at(TokenKind::Throws) { + self.with_node(SyntaxKind::THROWS_CLAUSE, |p| { + p.bump(); // throws + p.parse_type(); + }); + } // The caller's with_node(TYPE_EXPR) will wrap this appropriately } else { // Not a function type - should be a parenthesized type diff --git a/baml_language/crates/baml_compiler_syntax/src/ast.rs b/baml_language/crates/baml_compiler_syntax/src/ast.rs index 5c5d4a39b1..ff0b73acbd 100644 --- a/baml_language/crates/baml_compiler_syntax/src/ast.rs +++ b/baml_language/crates/baml_compiler_syntax/src/ast.rs @@ -2760,6 +2760,14 @@ impl ThrowsClause { } } +impl TypeExpr { + /// Returns the throws clause of a function type expression, if present. + /// Only meaningful when `self.is_function_type()` is true. + pub fn function_throws_type(&self) -> Option { + self.syntax.children().find_map(ThrowsClause::cast) + } +} + impl CatchExpr { /// Get the base expression before the first catch clause. pub fn base(&self) -> Option { diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/basic.baml b/baml_language/crates/baml_tests/projects/function_type_throws/basic.baml new file mode 100644 index 0000000000..fd41e0f48f --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/basic.baml @@ -0,0 +1,8 @@ +// Test: function type with throws clause +// This verifies that () -> int throws string parses correctly + +type ThrowingFn = () -> int throws string + +function takes_throwing(f: () -> int throws string) -> int { + f() +} diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap index 0cd81babcb..af8005ccd3 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap @@ -271,7 +271,7 @@ function baml.llm.Client.build_attempt_with_state(self: baml.llm.Client, planner match (self.client_type : baml.llm.ClientType) : baml.llm.OrchestrationStep[] | never[] ClientType.Primitive => { : baml.llm.OrchestrationStep[] - let resolve_fn = self.get_constructor() : () -> baml.llm.PrimitiveClient + let resolve_fn = self.get_constructor() : () -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument let primitive = resolve_fn() : baml.llm.PrimitiveClient [OrchestrationStep { primitive_client: primitive, delay_ms: 0 }] : baml.llm.OrchestrationStep[] } @@ -397,7 +397,7 @@ function baml.llm.Client.execute_once(self: baml.llm.Client, context: baml.ll match (self.client_type : baml.llm.ClientType) : T ClientType.Primitive => { : T - let resolve_fn = self.get_constructor() : () -> baml.llm.PrimitiveClient + let resolve_fn = self.get_constructor() : () -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument let primitive = resolve_fn() : baml.llm.PrimitiveClient let return_type = get_return_type(context.function_name) : type let prompt = primitive.render_prompt(context.jinja_string, context.args, return_type) : baml.llm.PromptAst diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap index 4f063cb10a..da82f7bcf5 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap @@ -650,7 +650,7 @@ fn testing.run_test(body: () -> null, runner: (() -> testing.TestReport) -> () - } bb1: { - _6 = is_type(copy _2, Function { params: [Function { params: [], ret: Class(TypeName { name: "TestReport", module_path: ["testing"], display_name: "testing.TestReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }], ret: Function { params: [], ret: Class(TypeName { name: "TestReport", module_path: ["testing"], display_name: "testing.TestReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + _6 = is_type(copy _2, Function { params: [Function { params: [], ret: Class(TypeName { name: "TestReport", module_path: ["testing"], display_name: "testing.TestReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }], ret: Function { params: [], ret: Class(TypeName { name: "TestReport", module_path: ["testing"], display_name: "testing.TestReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }, throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); branch copy _6 -> [bb3, bb2]; } @@ -919,7 +919,7 @@ fn testing.run_testset(run_children: () -> testing.TestSetReport, runner: (() -> } bb1: { - _5 = is_type(copy _2, Function { params: [Function { params: [], ret: Class(TypeName { name: "TestSetReport", module_path: ["testing"], display_name: "testing.TestSetReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }], ret: Function { params: [], ret: Class(TypeName { name: "TestSetReport", module_path: ["testing"], display_name: "testing.TestSetReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + _5 = is_type(copy _2, Function { params: [Function { params: [], ret: Class(TypeName { name: "TestSetReport", module_path: ["testing"], display_name: "testing.TestSetReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }], ret: Function { params: [], ret: Class(TypeName { name: "TestSetReport", module_path: ["testing"], display_name: "testing.TestSetReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }, throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); branch copy _5 -> [bb3, bb2]; } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__basic.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__basic.snap new file mode 100644 index 0000000000..fb77051dac --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__basic.snap @@ -0,0 +1,53 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +Word "Test" +Colon ":" +Function "function" +Word "type" +Word "with" +Throws "throws" +Word "clause" +Slash "/" +Slash "/" +Word "This" +Word "verifies" +Word "that" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +Word "parses" +Word "correctly" +Word "type" +Word "ThrowingFn" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +Function "function" +Word "takes_throwing" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__basic.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__basic.snap new file mode 100644 index 0000000000..152ff38598 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__basic.snap @@ -0,0 +1,53 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + TYPE_ALIAS_DEF + WORD "type" + WORD "ThrowingFn" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "takes_throwing" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap new file mode 100644 index 0000000000..93a1634ad8 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -0,0 +1,8 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== HIR2 === +type user.ThrowingFn = () -> int +function user.takes_throwing(f: () -> int) -> int [expr] { + { } f() +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap new file mode 100644 index 0000000000..09017a51b9 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -0,0 +1,21 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== MIR2 === +fn user.takes_throwing(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap new file mode 100644 index 0000000000..269d9a72b5 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -0,0 +1,11 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== TIR2 === +type user.ThrowingFn = () -> int throws string +function user.takes_throwing(f: () -> int throws string) -> int throws never { + { : int + f() : int + } +} +type user.ThrowingFn$stream = unknown diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap new file mode 100644 index 0000000000..40cf564409 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== COMPILER2 DIAGNOSTICS === +No errors found. diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap new file mode 100644 index 0000000000..3af59d72a8 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -0,0 +1,8 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +function user.takes_throwing(f: () -> int) -> int { + load_var f + call_indirect + return +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__basic.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__basic.snap new file mode 100644 index 0000000000..4c2658f291 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__basic.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Unexpected additional element at basic.baml:4:28 diff --git a/baml_language/crates/baml_type/src/lib.rs b/baml_language/crates/baml_type/src/lib.rs index b277d1cda4..4091306899 100644 --- a/baml_language/crates/baml_type/src/lib.rs +++ b/baml_language/crates/baml_type/src/lib.rs @@ -161,10 +161,11 @@ pub enum Ty { // --- Compiler-specific: present in VIR/MIR, absent at runtime --- /// Only recursive aliases survive lower_ty; non-recursive are expanded. TypeAlias(TypeName, TyAttr), - /// Function/arrow type: `(T1, T2, ...) -> R` + /// Function/arrow type: `(T1, T2, ...) -> R [throws E]` Function { params: Vec, ret: Box, + throws: Option>, attr: TyAttr, }, /// Void type — the type of effectful expressions (was VIR `Unit`). @@ -233,7 +234,17 @@ impl Ty { Ty::Union(members, _) => Ty::Union(members, attr), Ty::Opaque(tn, _) => Ty::Opaque(tn, attr), Ty::TypeAlias(tn, _) => Ty::TypeAlias(tn, attr), - Ty::Function { params, ret, .. } => Ty::Function { params, ret, attr }, + Ty::Function { + params, + ret, + throws, + .. + } => Ty::Function { + params, + ret, + throws, + attr, + }, Ty::WatchAccessor(inner, _) => Ty::WatchAccessor(inner, attr), Ty::Future(inner, _) => Ty::Future(inner, attr), } From a0af77cbdc16e0ff5316c5db75ceae2611bc8beb Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 03:21:12 -0500 Subject: [PATCH 02/26] baml_language: wire throws into function signature construction Preserve throw contracts when functions are used as values: free functions, class methods, and builtins now include their declared throws in the constructed Ty::Function. Lambda bodies infer throws from their expressions (or validate against explicit annotations). Adds comprehensive test coverage for lambda inference, explicit annotations, contract violations, nested lambdas, HOF patterns, function declarations, type aliases, and class field function types. --- .../crates/baml_compiler2_tir/src/builder.rs | 171 +++- .../baml_compiler2_tir/src/inference.rs | 14 + .../class_field_fn_throws.baml | 23 + .../function_type_throws/fn_decl_throws.baml | 24 + .../fn_type_alias_throws.baml | 16 + .../function_type_throws/hof_throws.baml | 36 + .../lambda_throws_explicit.baml | 19 + .../lambda_throws_infer.baml | 40 + .../lambda_throws_violation.baml | 7 + .../nested_lambda_throws.baml | 29 + ...rows__01_lexer__class_field_fn_throws.snap | 108 +++ ...type_throws__01_lexer__fn_decl_throws.snap | 141 +++ ...hrows__01_lexer__fn_type_alias_throws.snap | 106 +++ ...ion_type_throws__01_lexer__hof_throws.snap | 209 +++++ ...ows__01_lexer__lambda_throws_explicit.snap | 132 +++ ...throws__01_lexer__lambda_throws_infer.snap | 227 +++++ ...ws__01_lexer__lambda_throws_violation.snap | 57 ++ ...hrows__01_lexer__nested_lambda_throws.snap | 150 ++++ ...ows__02_parser__class_field_fn_throws.snap | 135 +++ ...ype_throws__02_parser__fn_decl_throws.snap | 131 +++ ...rows__02_parser__fn_type_alias_throws.snap | 84 ++ ...on_type_throws__02_parser__hof_throws.snap | 246 +++++ ...ws__02_parser__lambda_throws_explicit.snap | 125 +++ ...hrows__02_parser__lambda_throws_infer.snap | 257 ++++++ ...s__02_parser__lambda_throws_violation.snap | 51 ++ ...rows__02_parser__nested_lambda_throws.snap | 184 ++++ ...l_tests__function_type_throws__03_hir.snap | 90 ++ ...tests__function_type_throws__04_5_mir.snap | 838 ++++++++++++++++++ ...l_tests__function_type_throws__04_tir.snap | 309 +++++++ ..._function_type_throws__05_diagnostics.snap | 40 +- ...sts__function_type_throws__06_codegen.snap | 184 ++++ ...__10_formatter__class_field_fn_throws.snap | 5 + ..._throws__10_formatter__fn_decl_throws.snap | 5 + ...s__10_formatter__fn_type_alias_throws.snap | 5 + ...type_throws__10_formatter__hof_throws.snap | 5 + ..._10_formatter__lambda_throws_explicit.snap | 28 + ...ws__10_formatter__lambda_throws_infer.snap | 55 ++ ...10_formatter__lambda_throws_violation.snap | 12 + ...s__10_formatter__nested_lambda_throws.snap | 38 + .../baml_tests__lambda_basic__04_tir.snap | 4 +- 40 files changed, 4311 insertions(+), 29 deletions(-) create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/class_field_fn_throws.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/fn_decl_throws.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/fn_type_alias_throws.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_explicit.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_infer.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_violation.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/nested_lambda_throws.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__class_field_fn_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_decl_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_type_alias_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_explicit.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_infer.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_violation.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__nested_lambda_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__class_field_fn_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_decl_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_type_alias_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_explicit.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_infer.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_violation.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__nested_lambda_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__class_field_fn_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_decl_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_type_alias_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_explicit.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_infer.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_violation.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__nested_lambda_throws.snap diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index aed58bd3b9..af81ae3f1f 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -715,19 +715,37 @@ impl<'db> TypeInferenceBuilder<'db> { }); // Infer the lambda body using save/restore approach - let (ret_ty, _lambda_expressions) = self.infer_lambda_body( + let (ret_ty, e_body, _lambda_expressions) = self.infer_lambda_body( func_def, ¶m_tys, return_annotation.as_ref(), expr_id, ); + // Determine throws: explicit annotation takes precedence, + // otherwise infer from the body. + let throws_ty = if let Some(te) = &func_def.throws { + let mut diags = Vec::new(); + let declared = crate::lower_type_expr::lower_type_expr_in_ns( + self.context.db(), + &te.expr, + self.package_items, + &self.ns_context, + &self.generic_params, + &mut diags, + ); + for diag in diags { + self.context.report_at_span(diag, te.span); + } + declared + } else { + e_body + }; + Ty::Function { params: param_tys, ret: Box::new(ret_ty), - throws: Box::new(Ty::Never { - attr: TyAttr::default(), - }), + throws: Box::new(throws_ty), attr: TyAttr::default(), } } @@ -1088,6 +1106,7 @@ impl<'db> TypeInferenceBuilder<'db> { Ty::Function { params: expected_params, ret: expected_ret, + throws: expected_throws, .. } => { // Checking mode: decompose expected function type. @@ -1167,19 +1186,38 @@ impl<'db> TypeInferenceBuilder<'db> { return_annotation.as_ref().unwrap_or(expected_ret.as_ref()); // Infer/check the lambda body using save/restore approach - let (ret_ty, _lambda_expressions) = self.infer_lambda_body( + let (ret_ty, e_body, _lambda_expressions) = self.infer_lambda_body( func_def, ¶m_tys, Some(effective_ret), expr_id, ); + // Determine throws: explicit annotation > expected > inferred from body + let throws_ty = if let Some(te) = &func_def.throws { + let mut diags = Vec::new(); + let declared = crate::lower_type_expr::lower_type_expr_in_ns( + self.context.db(), + &te.expr, + self.package_items, + &self.ns_context, + &self.generic_params, + &mut diags, + ); + for diag in diags { + self.context.report_at_span(diag, te.span); + } + declared + } else { + // No annotation: infer from body, use expected_throws for TypeVar binding + let _ = expected_throws; + e_body + }; + let result = Ty::Function { params: param_tys, ret: Box::new(ret_ty), - throws: Box::new(Ty::Never { - attr: TyAttr::default(), - }), + throws: Box::new(throws_ty), attr: TyAttr::default(), }; self.record_expr_type(expr_id, result.clone()); @@ -2503,16 +2541,27 @@ impl<'db> TypeInferenceBuilder<'db> { out.extend(transitive.iter().cloned()); } else { // Target not in throw set registry (function parameter, - // external function, etc.) — conservatively assume it - // could throw anything. + // external function, etc.). Check if the callee has a + // type-level throws annotation before falling back to Unknown. + if let Some(Ty::Function { throws, .. }) = self.expressions.get(callee) { + let facts = crate::throw_inference::flatten_ty_to_facts(throws); + out.extend(facts); + } else { + out.insert(Ty::Unknown { + attr: TyAttr::default(), + }); + } + } + } else { + // No resolvable target name — check type-level throws. + if let Some(Ty::Function { throws, .. }) = self.expressions.get(callee) { + let facts = crate::throw_inference::flatten_ty_to_facts(throws); + out.extend(facts); + } else { out.insert(Ty::Unknown { attr: TyAttr::default(), }); } - } else { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); } } Expr::Catch { clauses, .. } => { @@ -3048,9 +3097,23 @@ impl<'db> TypeInferenceBuilder<'db> { attr: TyAttr::default(), }), ), - throws: Box::new(Ty::Never { - attr: TyAttr::default(), - }), + throws: Box::new( + sig.throws + .as_ref() + .map(|te| { + crate::lower_type_expr::lower_type_expr_in_ns( + db, + te, + pkg_items, + &ns_context, + generic_params, + &mut diags, + ) + }) + .unwrap_or(Ty::Never { + attr: TyAttr::default(), + }), + ), attr: TyAttr::default(), }; return Some(ty); @@ -3139,9 +3202,23 @@ impl<'db> TypeInferenceBuilder<'db> { attr: TyAttr::default(), }), ), - throws: Box::new(Ty::Never { - attr: TyAttr::default(), - }), + throws: Box::new( + sig.throws + .as_ref() + .map(|te| { + crate::lower_type_expr::lower_type_expr_in_ns( + db, + te, + self.package_items, + &sig_ns, + generic_params, + &mut diags, + ) + }) + .unwrap_or(Ty::Never { + attr: TyAttr::default(), + }), + ), attr: TyAttr::default(), } } @@ -3647,9 +3724,23 @@ impl<'db> TypeInferenceBuilder<'db> { attr: TyAttr::default(), }), ), - throws: Box::new(Ty::Never { - attr: TyAttr::default(), - }), + throws: Box::new( + sig.throws + .as_ref() + .map(|te| { + crate::lower_type_expr::lower_type_expr_in_ns( + db, + te, + pkg_items_for_class, + &ns_context, + &all_generic_params, + &mut diags, + ) + }) + .unwrap_or(Ty::Never { + attr: TyAttr::default(), + }), + ), attr: TyAttr::default(), }; // Note: diags from method signatures are reported at definition site. @@ -4882,6 +4973,22 @@ impl<'db> TypeInferenceBuilder<'db> { } } + /// Collapse a set of throw fact types into a single `Ty`. + /// + /// - Empty set → `Ty::Never` (pure, no throws) + /// - Single element → that element + /// - Multiple elements → `Ty::Union` + fn collapse_throw_set(set: std::collections::BTreeSet) -> Ty { + let mut members: Vec = set.into_iter().collect(); + match members.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => members.remove(0), + _ => Ty::Union(members, TyAttr::default()), + } + } + /// Infer/check a lambda body using a save/restore approach. /// /// Saves the current locals, `declared_types`, `declared_return_ty`, @@ -4889,7 +4996,8 @@ impl<'db> TypeInferenceBuilder<'db> { /// the lambda's arena and the parent's arena). After inference, restores all /// saved state and returns the lambda's expression types separately. /// - /// Returns `(inferred_return_ty, lambda_expressions)` where + /// Returns `(inferred_return_ty, effective_throws_ty, lambda_expressions)` where + /// `effective_throws_ty` is the inferred throws effect from the lambda body, and /// `lambda_expressions` contains the expression types for the lambda body /// only (keyed by the lambda's own `ExprId`s, which start at 0). pub fn infer_lambda_body( @@ -4898,7 +5006,7 @@ impl<'db> TypeInferenceBuilder<'db> { param_tys: &[(Option, Ty)], expected_ret: Option<&Ty>, _lambda_expr_id: ExprId, - ) -> (Ty, FxHashMap) { + ) -> (Ty, Ty, FxHashMap) { use baml_compiler2_ast::FunctionBodyDef; // Get the lambda's ExprBody @@ -4907,6 +5015,9 @@ impl<'db> TypeInferenceBuilder<'db> { Ty::Unknown { attr: TyAttr::default(), }, + Ty::Never { + attr: TyAttr::default(), + }, FxHashMap::default(), ); }; @@ -4916,6 +5027,9 @@ impl<'db> TypeInferenceBuilder<'db> { Ty::Void { attr: TyAttr::default(), }, + Ty::Never { + attr: TyAttr::default(), + }, FxHashMap::default(), ); }; @@ -4993,6 +5107,11 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expr(root_expr, lambda_body) }; + // Collect effective throws from the lambda body BEFORE restoring state, + // since collect_effective_throws uses catch_residual_throws and expressions. + let effective_throws_set = self.collect_effective_throws(lambda_body); + let effective_throws_ty = Self::collapse_throw_set(effective_throws_set); + // Collect the lambda's expression types and restore parent state let lambda_expressions = std::mem::replace(&mut self.expressions, saved_expressions); self.bindings = saved_bindings; @@ -5004,7 +5123,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.declared_return_ty = saved_return_ty; self.generic_params = saved_generic_params; - (ret_ty, lambda_expressions) + (ret_ty, effective_throws_ty, lambda_expressions) } } diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 2f0b0fb5da..93566486e7 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -478,6 +478,13 @@ pub fn infer_scope_types<'db>( if let Some(root_expr) = lambda_body.root_expr { builder.infer_expr(root_expr, lambda_body); } + // Validate declared `throws` against effective escaping throws. + builder.check_throws_contract( + lambda_body, + func_def.throws.as_ref().map(|te| &te.expr), + func_def.throws.as_ref().map(|te| te.span), + func_def.span, + ); } } break 'ancestor_walk; @@ -527,6 +534,13 @@ pub fn infer_scope_types<'db>( if let Some(root_expr) = lambda_body.root_expr { builder.infer_expr(root_expr, lambda_body); } + // Validate declared `throws` against effective escaping throws. + builder.check_throws_contract( + lambda_body, + func_def.throws.as_ref().map(|te| &te.expr), + func_def.throws.as_ref().map(|te| te.span), + func_def.span, + ); } } break 'ancestor_walk; diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/class_field_fn_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/class_field_fn_throws.baml new file mode 100644 index 0000000000..cc1489a0c0 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/class_field_fn_throws.baml @@ -0,0 +1,23 @@ +// === Class fields with function types and throws === + +class PureHandler { + run: () -> null +} + +class ThrowingHandler { + run: () -> null throws string +} + +class MixedHandler { + safe: () -> int + risky: () -> int throws string +} + +// Function returning a class with function fields +function make_pure_handler() -> PureHandler { + PureHandler { run: () -> null { null } } +} + +function make_throwing_handler() -> ThrowingHandler { + ThrowingHandler { run: () -> null { throw "error" } } +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/fn_decl_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/fn_decl_throws.baml new file mode 100644 index 0000000000..edfe691922 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/fn_decl_throws.baml @@ -0,0 +1,24 @@ +// === Function declaration throws preserved in type === + +// Function with explicit throws - should show throws string in type +function may_fail(x: int) -> string throws string { + if (x == 0) { throw "zero" } + "ok" +} + +// Function with no throws - should show throws never +function always_ok(x: int) -> string { + "always ok" +} + +// Function calling a throwing function - should propagate throws +function caller() -> string { + may_fail(1) +} + +// Function calling a throwing function with catch - should not propagate +function safe_caller() -> string { + may_fail(1) catch (e) { + _ => "caught" + } +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/fn_type_alias_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/fn_type_alias_throws.baml new file mode 100644 index 0000000000..29091e4fbc --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/fn_type_alias_throws.baml @@ -0,0 +1,16 @@ +// === Function type aliases with throws === + +// Type alias with explicit throws +type ThrowingCallback = () -> int throws string + +// Type alias without throws - defaults to throws never +type PureCallback = () -> int + +// Type alias with throws never explicitly +type ExplicitPure = () -> int throws never + +// Parameterized function type with throws +type Mapper = (int) -> string throws string + +// Nested function type - outer throws, inner pure +type Wrapper = (() -> int) -> int throws string diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml new file mode 100644 index 0000000000..94b0004d1f --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml @@ -0,0 +1,36 @@ +// === Higher-order function throws propagation === + +// Function taking a callback and calling it +function apply(f: () -> int) -> int { + f() +} + +// Calling apply with a pure lambda +function test_apply_pure() -> int { + apply(() -> int { 42 }) +} + +// Calling apply with a throwing lambda +function test_apply_throwing() -> int { + apply(() -> int { throw "boom" }) +} + +// Function taking callback with explicit throws +function apply_throwing(f: () -> int throws string) -> int { + f() +} + +function test_apply_explicit_throws() -> int { + apply_throwing(() -> int { throw "boom" }) +} + +// Function that takes a callback and calls it, with its own throw +function apply_and_throw(f: () -> int) -> int { + let result = f() + if (result < 0) { throw "negative result" } + result +} + +function test_apply_and_throw_pure() -> int { + apply_and_throw(() -> int { 42 }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_explicit.baml b/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_explicit.baml new file mode 100644 index 0000000000..2114699ddc --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_explicit.baml @@ -0,0 +1,19 @@ +// === Lambda explicit throws annotations === + +// Lambda with explicit throws string - body matches +function test_explicit_throws_match() -> int { + let f = () -> int throws string { throw "boom" } + f() +} + +// Lambda with explicit throws never - body is pure (should be fine) +function test_explicit_throws_never_pure() -> int { + let f = () -> int throws never { 42 } + f() +} + +// Lambda with explicit throws string - body is pure (should be fine, never <: string) +function test_explicit_throws_wider_than_body() -> int { + let f = () -> int throws string { 42 } + f() +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_infer.baml b/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_infer.baml new file mode 100644 index 0000000000..05fa0d8843 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_infer.baml @@ -0,0 +1,40 @@ +// === Lambda throws inference === + +// Pure lambda - should infer throws: never +function test_pure_lambda() -> int { + let f = () -> int { 42 } + f() +} + +// Throwing lambda - should infer throws: string +function test_throwing_lambda() -> int { + let f = () -> int { throw "boom" } + f() +} + +// Lambda throwing int - should infer throws: int +function test_throwing_int() -> string { + let f = () -> string { throw 42 } + f() +} + +// Lambda with conditional throw - only one branch throws +function test_conditional_throw(x: int) -> int { + let f = (n: int) -> int { + if (n < 0) { throw "negative" } + n + } + f(x) +} + +// Lambda with multiple throw types - should infer throws: string | int +function test_multi_throw_types(x: int) -> string { + let f = (n: int) -> string { + match (n) { + 0 => throw "string error", + 1 => throw 1, + _ => "ok" + } + } + f(x) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_violation.baml b/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_violation.baml new file mode 100644 index 0000000000..85fcfff16e --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_violation.baml @@ -0,0 +1,7 @@ +// === Lambda throws contract violations (should produce errors) === + +// Lambda with explicit throws never but body throws - should error +function test_throws_never_but_throws() -> int { + let f = () -> int throws never { throw "boom" } + f() +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/nested_lambda_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/nested_lambda_throws.baml new file mode 100644 index 0000000000..526251fef9 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/nested_lambda_throws.baml @@ -0,0 +1,29 @@ +// === Nested lambda throws behavior === + +// Inner lambda throws, outer doesn't throw directly +function test_nested_inner_throws() -> int { + let outer = () -> int { + let inner = () -> int { throw "inner boom" } + inner() + } + outer() +} + +// Outer lambda throws, inner is pure +function test_nested_outer_throws() -> int { + let outer = () -> int { + let inner = () -> int { 42 } + throw "outer boom" + } + outer() +} + +// Both throw different types +function test_nested_both_throw() -> int { + let outer = () -> int { + let inner = () -> int { throw 42 } + inner() + throw "outer" + } + outer() +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__class_field_fn_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__class_field_fn_throws.snap new file mode 100644 index 0000000000..02b0010e06 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__class_field_fn_throws.snap @@ -0,0 +1,108 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Class" +Word "fields" +Word "with" +Function "function" +Word "types" +Word "and" +Throws "throws" +EqualsEquals "==" +Equals "=" +Class "class" +Word "PureHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +RBrace "}" +Class "class" +Word "ThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "string" +RBrace "}" +Class "class" +Word "MixedHandler" +LBrace "{" +Word "safe" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Word "risky" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +RBrace "}" +Slash "/" +Slash "/" +Word "Function" +Word "returning" +Word "a" +Class "class" +Word "with" +Function "function" +Word "fields" +Function "function" +Word "make_pure_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "PureHandler" +LBrace "{" +Word "PureHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "null" +RBrace "}" +RBrace "}" +RBrace "}" +Function "function" +Word "make_throwing_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "ThrowingHandler" +LBrace "{" +Word "ThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_decl_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_decl_throws.snap new file mode 100644 index 0000000000..7a1084d342 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_decl_throws.snap @@ -0,0 +1,141 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Function" +Word "declaration" +Throws "throws" +Word "preserved" +In "in" +Word "type" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Function" +Word "with" +Word "explicit" +Throws "throws" +Minus "-" +Word "should" +Word "show" +Throws "throws" +Word "string" +In "in" +Word "type" +Function "function" +Word "may_fail" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "string" +LBrace "{" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Quote "\"" +Word "zero" +Quote "\"" +RBrace "}" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +Slash "/" +Slash "/" +Word "Function" +Word "with" +Word "no" +Throws "throws" +Minus "-" +Word "should" +Word "show" +Throws "throws" +Word "never" +Function "function" +Word "always_ok" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Quote "\"" +Word "always" +Word "ok" +Quote "\"" +RBrace "}" +Slash "/" +Slash "/" +Word "Function" +Word "calling" +Word "a" +Word "throwing" +Function "function" +Minus "-" +Word "should" +Word "propagate" +Throws "throws" +Function "function" +Word "caller" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "may_fail" +LParen "(" +IntegerLiteral "1" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Function" +Word "calling" +Word "a" +Word "throwing" +Function "function" +Word "with" +Catch "catch" +Minus "-" +Word "should" +Word "not" +Word "propagate" +Function "function" +Word "safe_caller" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "may_fail" +LParen "(" +IntegerLiteral "1" +RParen ")" +Catch "catch" +LParen "(" +Word "e" +RParen ")" +LBrace "{" +Word "_" +FatArrow "=>" +Quote "\"" +Word "caught" +Quote "\"" +RBrace "}" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_type_alias_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_type_alias_throws.snap new file mode 100644 index 0000000000..3ee4e07410 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_type_alias_throws.snap @@ -0,0 +1,106 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Function" +Word "type" +Word "aliases" +Word "with" +Throws "throws" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Type" +Word "alias" +Word "with" +Word "explicit" +Throws "throws" +Word "type" +Word "ThrowingCallback" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +Slash "/" +Slash "/" +Word "Type" +Word "alias" +Word "without" +Throws "throws" +Minus "-" +Word "defaults" +Word "to" +Throws "throws" +Word "never" +Word "type" +Word "PureCallback" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Slash "/" +Slash "/" +Word "Type" +Word "alias" +Word "with" +Throws "throws" +Word "never" +Word "explicitly" +Word "type" +Word "ExplicitPure" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "never" +Slash "/" +Slash "/" +Word "Parameterized" +Function "function" +Word "type" +Word "with" +Throws "throws" +Word "type" +Word "Mapper" +Equals "=" +LParen "(" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "string" +Slash "/" +Slash "/" +Word "Nested" +Function "function" +Word "type" +Minus "-" +Word "outer" +Throws "throws" +Comma "," +Word "inner" +Word "pure" +Word "type" +Word "Wrapper" +Equals "=" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap new file mode 100644 index 0000000000..73a4493d97 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap @@ -0,0 +1,209 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Higher-order" +Function "function" +Throws "throws" +Word "propagation" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Function" +Word "taking" +Word "a" +Word "callback" +Word "and" +Word "calling" +Word "it" +Function "function" +Word "apply" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Calling" +Word "apply" +Word "with" +Word "a" +Word "pure" +Word "lambda" +Function "function" +Word "test_apply_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Calling" +Word "apply" +Word "with" +Word "a" +Word "throwing" +Word "lambda" +Function "function" +Word "test_apply_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "boom" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Function" +Word "taking" +Word "callback" +Word "with" +Word "explicit" +Throws "throws" +Function "function" +Word "apply_throwing" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_apply_explicit_throws" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_throwing" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "boom" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Function" +Word "that" +Word "takes" +Word "a" +Word "callback" +Word "and" +Word "calls" +Word "it" +Comma "," +Word "with" +Word "its" +Word "own" +Throw "throw" +Function "function" +Word "apply_and_throw" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "result" +Equals "=" +Word "f" +LParen "(" +RParen ")" +If "if" +LParen "(" +Word "result" +Less "<" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Quote "\"" +Word "negative" +Word "result" +Quote "\"" +RBrace "}" +Word "result" +RBrace "}" +Function "function" +Word "test_apply_and_throw_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_and_throw" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_explicit.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_explicit.snap new file mode 100644 index 0000000000..5148bb11c7 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_explicit.snap @@ -0,0 +1,132 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Lambda" +Word "explicit" +Throws "throws" +Word "annotations" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Lambda" +Word "with" +Word "explicit" +Throws "throws" +Word "string" +Minus "-" +Word "body" +Word "matches" +Function "function" +Word "test_explicit_throws_match" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +LBrace "{" +Throw "throw" +Quote "\"" +Word "boom" +Quote "\"" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Lambda" +Word "with" +Word "explicit" +Throws "throws" +Word "never" +Minus "-" +Word "body" +Word "is" +Word "pure" +LParen "(" +Word "should" +Word "be" +Word "fine" +RParen ")" +Function "function" +Word "test_explicit_throws_never_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "never" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Lambda" +Word "with" +Word "explicit" +Throws "throws" +Word "string" +Minus "-" +Word "body" +Word "is" +Word "pure" +LParen "(" +Word "should" +Word "be" +Word "fine" +Comma "," +Word "never" +Less "<" +Colon ":" +Word "string" +RParen ")" +Function "function" +Word "test_explicit_throws_wider_than_body" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_infer.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_infer.snap new file mode 100644 index 0000000000..2c165fd40d --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_infer.snap @@ -0,0 +1,227 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Lambda" +Throws "throws" +Word "inference" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Pure" +Word "lambda" +Minus "-" +Word "should" +Word "infer" +Throws "throws" +Colon ":" +Word "never" +Function "function" +Word "test_pure_lambda" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Throwing" +Word "lambda" +Minus "-" +Word "should" +Word "infer" +Throws "throws" +Colon ":" +Word "string" +Function "function" +Word "test_throwing_lambda" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "boom" +Quote "\"" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Lambda" +Word "throwing" +Word "int" +Minus "-" +Word "should" +Word "infer" +Throws "throws" +Colon ":" +Word "int" +Function "function" +Word "test_throwing_int" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Lambda" +Word "with" +Word "conditional" +Throw "throw" +Minus "-" +Word "only" +Word "one" +Word "branch" +Throws "throws" +Function "function" +Word "test_conditional_throw" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +Word "n" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +If "if" +LParen "(" +Word "n" +Less "<" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Quote "\"" +Word "negative" +Quote "\"" +RBrace "}" +Word "n" +RBrace "}" +Word "f" +LParen "(" +Word "x" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Lambda" +Word "with" +Word "multiple" +Throw "throw" +Word "types" +Minus "-" +Word "should" +Word "infer" +Throws "throws" +Colon ":" +Word "string" +Pipe "|" +Word "int" +Function "function" +Word "test_multi_throw_types" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +Word "n" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Match "match" +LParen "(" +Word "n" +RParen ")" +LBrace "{" +IntegerLiteral "0" +FatArrow "=>" +Throw "throw" +Quote "\"" +Word "string" +Word "error" +Quote "\"" +Comma "," +IntegerLiteral "1" +FatArrow "=>" +Throw "throw" +IntegerLiteral "1" +Comma "," +Word "_" +FatArrow "=>" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +RBrace "}" +Word "f" +LParen "(" +Word "x" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_violation.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_violation.snap new file mode 100644 index 0000000000..485cddc227 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_violation.snap @@ -0,0 +1,57 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Lambda" +Throws "throws" +Word "contract" +Word "violations" +LParen "(" +Word "should" +Word "produce" +Word "errors" +RParen ")" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Lambda" +Word "with" +Word "explicit" +Throws "throws" +Word "never" +Word "but" +Word "body" +Throws "throws" +Minus "-" +Word "should" +Word "error" +Function "function" +Word "test_throws_never_but_throws" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "never" +LBrace "{" +Throw "throw" +Quote "\"" +Word "boom" +Quote "\"" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__nested_lambda_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__nested_lambda_throws.snap new file mode 100644 index 0000000000..705fa5965b --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__nested_lambda_throws.snap @@ -0,0 +1,150 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Nested" +Word "lambda" +Throws "throws" +Word "behavior" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Inner" +Word "lambda" +Throws "throws" +Comma "," +Word "outer" +Word "doesn" +Error "'" +Word "t" +Throw "throw" +Word "directly" +Function "function" +Word "test_nested_inner_throws" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "outer" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "inner" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "inner" +Word "boom" +Quote "\"" +RBrace "}" +Word "inner" +LParen "(" +RParen ")" +RBrace "}" +Word "outer" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Outer" +Word "lambda" +Throws "throws" +Comma "," +Word "inner" +Word "is" +Word "pure" +Function "function" +Word "test_nested_outer_throws" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "outer" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "inner" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +Throw "throw" +Quote "\"" +Word "outer" +Word "boom" +Quote "\"" +RBrace "}" +Word "outer" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Both" +Throw "throw" +Word "different" +Word "types" +Function "function" +Word "test_nested_both_throw" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "outer" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "inner" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +Word "inner" +LParen "(" +RParen ")" +Throw "throw" +Quote "\"" +Word "outer" +Quote "\"" +RBrace "}" +Word "outer" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__class_field_fn_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__class_field_fn_throws.snap new file mode 100644 index 0000000000..cbeebc1074 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__class_field_fn_throws.snap @@ -0,0 +1,135 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + CLASS_DEF + KW_CLASS "class" + WORD "PureHandler" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "ThrowingHandler" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "MixedHandler" + L_BRACE "{" + FIELD + WORD "safe" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + FIELD + WORD "risky" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_pure_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "PureHandler" + WORD "PureHandler" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "PureHandler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR "{ null }" + L_BRACE "{" + WORD "null" + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_throwing_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "ThrowingHandler" + WORD "ThrowingHandler" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "ThrowingHandler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_decl_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_decl_throws.snap new file mode 100644 index 0000000000..517de1bd1f --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_decl_throws.snap @@ -0,0 +1,131 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "may_fail" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 0" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "zero" + QUOTE """ + WORD "zero" + QUOTE """ + R_BRACE "}" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "always_ok" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + STRING_LITERAL "always ok" + QUOTE """ + WORD "always" + WORD "ok" + QUOTE """ + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "caller" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "may_fail" + CALL_ARGS "(1)" + L_PAREN "(" + INTEGER_LITERAL "1" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "safe_caller" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CATCH_EXPR + CALL_EXPR + WORD "may_fail" + CALL_ARGS "(1)" + L_PAREN "(" + INTEGER_LITERAL "1" + R_PAREN ")" + CATCH_CLAUSE + KW_CATCH "catch" + L_PAREN "(" + CATCH_PATTERN "e" + WORD "e" + R_PAREN ")" + L_BRACE "{" + CATCH_ARM + CATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + STRING_LITERAL "caught" + QUOTE """ + WORD "caught" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_type_alias_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_type_alias_throws.snap new file mode 100644 index 0000000000..09c1501ea9 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_type_alias_throws.snap @@ -0,0 +1,84 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + TYPE_ALIAS_DEF + WORD "type" + WORD "ThrowingCallback" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + TYPE_ALIAS_DEF + WORD "type" + WORD "PureCallback" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + TYPE_ALIAS_DEF + WORD "type" + WORD "ExplicitPure" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "never" + WORD "never" + TYPE_ALIAS_DEF + WORD "type" + WORD "Mapper" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + TYPE_ALIAS_DEF + WORD "type" + WORD "Wrapper" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap new file mode 100644 index 0000000000..1c7e5a3e33 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap @@ -0,0 +1,246 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "boom" + QUOTE """ + WORD "boom" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_throwing" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_explicit_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_throwing" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "boom" + QUOTE """ + WORD "boom" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_and_throw" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "result" + EQUALS "=" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "result < 0" + WORD "result" + LESS "<" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "negative result" + QUOTE """ + WORD "negative" + WORD "result" + QUOTE """ + R_BRACE "}" + WORD "result" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_and_throw_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_and_throw" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_explicit.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_explicit.snap new file mode 100644 index 0000000000..f3a5808c6a --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_explicit.snap @@ -0,0 +1,125 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_explicit_throws_match" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "boom" + QUOTE """ + WORD "boom" + QUOTE """ + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_explicit_throws_never_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "never" + WORD "never" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_explicit_throws_wider_than_body" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_infer.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_infer.snap new file mode 100644 index 0000000000..cb2b9e0530 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_infer.snap @@ -0,0 +1,257 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_pure_lambda" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_throwing_lambda" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "boom" + QUOTE """ + WORD "boom" + QUOTE """ + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_throwing_int" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_conditional_throw" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "n" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "n < 0" + WORD "n" + LESS "<" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "negative" + QUOTE """ + WORD "negative" + QUOTE """ + R_BRACE "}" + WORD "n" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "(x)" + L_PAREN "(" + WORD "x" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_multi_throw_types" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "n" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + MATCH_EXPR + KW_MATCH "match" + L_PAREN "(" + WORD "n" + R_PAREN ")" + L_BRACE "{" + MATCH_ARM + MATCH_PATTERN "0" + INTEGER_LITERAL "0" + FAT_ARROW "=>" + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "string error" + QUOTE """ + WORD "string" + WORD "error" + QUOTE """ + COMMA "," + MATCH_ARM + MATCH_PATTERN "1" + INTEGER_LITERAL "1" + FAT_ARROW "=>" + THROW_EXPR "throw 1" + KW_THROW "throw" + INTEGER_LITERAL "1" + COMMA "," + MATCH_ARM + MATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "(x)" + L_PAREN "(" + WORD "x" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_violation.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_violation.snap new file mode 100644 index 0000000000..e677bdcdfa --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_violation.snap @@ -0,0 +1,51 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_throws_never_but_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "never" + WORD "never" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "boom" + QUOTE """ + WORD "boom" + QUOTE """ + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__nested_lambda_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__nested_lambda_throws.snap new file mode 100644 index 0000000000..f30e7f1514 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__nested_lambda_throws.snap @@ -0,0 +1,184 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_nested_inner_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "outer" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "inner" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "inner boom" + QUOTE """ + WORD "inner" + WORD "boom" + QUOTE """ + R_BRACE "}" + CALL_EXPR + WORD "inner" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + CALL_EXPR + WORD "outer" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_nested_outer_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "outer" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "inner" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "outer boom" + QUOTE """ + WORD "outer" + WORD "boom" + QUOTE """ + R_BRACE "}" + CALL_EXPR + WORD "outer" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_nested_both_throw" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "outer" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "inner" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + CALL_EXPR + WORD "inner" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "outer" + QUOTE """ + WORD "outer" + QUOTE """ + R_BRACE "}" + CALL_EXPR + WORD "outer" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index 93a1634ad8..6884cbbc6a 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -6,3 +6,93 @@ type user.ThrowingFn = () -> int function user.takes_throwing(f: () -> int) -> int [expr] { { } f() } +class user.MixedHandler { + safe: () -> int + risky: () -> int +} +class user.PureHandler { + run: () -> null +} +class user.ThrowingHandler { + run: () -> null +} +function user.make_pure_handler() -> user.PureHandler [expr] { + { } user.PureHandler { run: () -> null { { } null } } +} +function user.make_throwing_handler() -> user.ThrowingHandler [expr] { + { } user.ThrowingHandler { run: () -> null { { throw "error" } } } +} +function user.always_ok(x: int) -> string [expr] { + { } "always ok" +} +function user.caller() -> string [expr] { + { } may_fail(1) +} +function user.may_fail(x: int) -> string [expr] { + { if (x Eq 0) { throw "zero" } } "ok" +} +function user.safe_caller() -> string [expr] { + { } may_fail(1) catch (e) { _ => "caught" } +} +type user.ExplicitPure = () -> int +type user.Mapper = (int) -> string +type user.PureCallback = () -> int +type user.ThrowingCallback = () -> int +type user.Wrapper = (() -> int) -> int +function user.apply(f: () -> int) -> int [expr] { + { } f() +} +function user.apply_and_throw(f: () -> int) -> int [expr] { + { let result = f(); if (result Lt 0) { throw "negative result" } } result +} +function user.apply_throwing(f: () -> int) -> int [expr] { + { } f() +} +function user.test_apply_and_throw_pure() -> int [expr] { + { } apply_and_throw(() -> int { { } 42 }) +} +function user.test_apply_explicit_throws() -> int [expr] { + { } apply_throwing(() -> int { { throw "boom" } }) +} +function user.test_apply_pure() -> int [expr] { + { } apply(() -> int { { } 42 }) +} +function user.test_apply_throwing() -> int [expr] { + { } apply(() -> int { { throw "boom" } }) +} +function user.test_explicit_throws_match() -> int [expr] { + { let f = () -> int throws string { { throw "boom" } } } f() +} +function user.test_explicit_throws_never_pure() -> int [expr] { + { let f = () -> int throws never { { } 42 } } f() +} +function user.test_explicit_throws_wider_than_body() -> int [expr] { + { let f = () -> int throws string { { } 42 } } f() +} +function user.test_conditional_throw(x: int) -> int [expr] { + { let f = (n: int) -> int { { if (n Lt 0) { throw "negative" } } n } } f(x) +} +function user.test_multi_throw_types(x: int) -> string [expr] { + { let f = (n: int) -> string { { } match (n) { 0 => throw "string error", 1 => throw 1, _ => "ok" } } } f(x) +} +function user.test_pure_lambda() -> int [expr] { + { let f = () -> int { { } 42 } } f() +} +function user.test_throwing_int() -> string [expr] { + { let f = () -> string { { throw 42 } } } f() +} +function user.test_throwing_lambda() -> int [expr] { + { let f = () -> int { { throw "boom" } } } f() +} +function user.test_throws_never_but_throws() -> int [expr] { + { let f = () -> int throws never { { throw "boom" } } } f() +} +function user.test_nested_both_throw() -> int [expr] { + { let outer = () -> int { { let inner = () -> int { { throw 42 } }; inner(); throw "outer" } } } outer() +} +function user.test_nested_inner_throws() -> int [expr] { + { let outer = () -> int { { let inner = () -> int { { throw "inner boom" } } } inner() } } outer() +} +function user.test_nested_outer_throws() -> int [expr] { + { let outer = () -> int { { let inner = () -> int { { } 42 }; throw "outer boom" } } } outer() +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index 09017a51b9..31f87a3c1f 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -19,3 +19,841 @@ fn user.takes_throwing(f: () -> int) -> int { return; } } + +fn user.make_pure_handler() -> PureHandler { + // Locals: + let _0: PureHandler // _0 // return + let _1: () -> null + + bb0: { + _1 = make_closure lambda[0](); + _0 = PureHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.make_throwing_handler() -> ThrowingHandler { + // Locals: + let _0: ThrowingHandler // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = ThrowingHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + +fn user.always_ok(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + + bb0: { + _0 = const "always ok"; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.caller() -> string { + // Locals: + let _0: string // _0 // return + + bb0: { + _0 = call const fn user.may_fail(const 1_i64) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.may_fail(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: bool + + bb0: { + _2 = copy _1 == const 0_i64; + branch copy _2 -> [bb4, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _0 = const "ok"; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw const "zero"; + } +} + +fn user.safe_caller() -> string { + // Locals: + let _0: string // _0 // return + let _1: unknown // e + + bb0: { + _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb2]; + } + + bb1: { + goto -> bb5; + } + + bb2: { + throw_if_panic copy _1 -> bb3; + } + + bb3: { + goto -> bb4; + } + + bb4: { + _0 = const "caught"; + goto -> bb5; + } + + bb5: { + goto -> bb6; + } + + bb6: { + return; + } +} + +fn user.apply(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_and_throw(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + let _2: int // result + let _3: bool + let _4: int + + bb0: { + _2 = call copy _1() -> [bb1]; + } + + bb1: { + _4 = copy _2; + _3 = copy _4 < const 0_i64; + branch copy _3 -> [bb5, bb2]; + } + + bb2: { + goto -> bb3; + } + + bb3: { + _0 = copy _2; + goto -> bb4; + } + + bb4: { + return; + } + + bb5: { + throw const "negative result"; + } +} + +fn user.apply_throwing(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.test_apply_and_throw_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_and_throw(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_apply_explicit_throws() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_throwing(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "boom"; + } +} + +fn user.test_apply_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_apply_throwing() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "boom"; + } +} + +fn user.test_explicit_throws_match() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void // f + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "boom"; + } +} + +fn user.test_explicit_throws_never_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 // f + let _2: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_explicit_throws_wider_than_body() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 // f + let _2: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_conditional_throw(x: int) -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x // param + let _2: (int) -> int // f + let _3: (int) -> int + + bb0: { + _2 = make_closure lambda[0](); + _3 = copy _2; + _0 = call copy _3(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(n: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // n // param + let _2: bool + + bb0: { + _2 = copy _1 < const 0_i64; + branch copy _2 -> [bb4, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _0 = copy _1; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw const "negative"; + } +} + +fn user.test_multi_throw_types(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: (int) -> "ok" // f + let _3: (int) -> "ok" + + bb0: { + _2 = make_closure lambda[0](); + _3 = copy _2; + _0 = call copy _3(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(n: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // n // param + + bb0: { + switch copy _1 [0: bb5, 1: bb4, otherwise: bb1]; + } + + bb1: { + _0 = const "ok"; + goto -> bb2; + } + + bb2: { + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw const 1_i64; + } + + bb5: { + throw const "string error"; + } +} + +fn user.test_pure_lambda() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 // f + let _2: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_throwing_int() -> string { + // Locals: + let _0: string // _0 // return + let _1: () -> void // f + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + +fn user.test_throwing_lambda() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void // f + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "boom"; + } +} + +fn user.test_throws_never_but_throws() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void // f + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "boom"; + } +} + +fn user.test_nested_both_throw() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void // outer + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + let _1: () -> void // inner + let _2: void + let _3: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _3 = copy _1; + _2 = call copy _3() -> [bb1]; + } + + bb1: { + throw const "outer"; + } +} + +// lambda[0] +fn ., 1)>() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + +fn user.test_nested_inner_throws() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void // outer + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + let _1: () -> void // inner + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn ., 1)>() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "inner boom"; + } +} + +fn user.test_nested_outer_throws() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void // outer + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "outer boom"; + } +} + +// lambda[0] +fn ., 1)>() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 269d9a72b5..129e309a7d 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -9,3 +9,312 @@ function user.takes_throwing(f: () -> int throws string) -> int throws never { } } type user.ThrowingFn$stream = unknown +class user.PureHandler { + run: () -> null +} +class user.ThrowingHandler { + run: () -> null throws string +} +class user.MixedHandler { + safe: () -> int + risky: () -> int throws string +} +function user.make_pure_handler() -> user.PureHandler throws never { + { : user.PureHandler + PureHandler { run: () -> null { ... } } : user.PureHandler + } +} +lambda user.make_pure_handler { +} +function user.make_throwing_handler() -> user.ThrowingHandler throws never { + { : user.ThrowingHandler + ThrowingHandler { run: () -> null { ... } } : user.ThrowingHandler + } +} +lambda user.make_throwing_handler { +} +class user.PureHandler$stream { + run: unknown +} +class user.ThrowingHandler$stream { + run: unknown +} +class user.MixedHandler$stream { + safe: unknown + risky: unknown +} +function user.may_fail(x: int) -> string throws string { + { : "ok" + if (x == 0 : bool) : void + { : never + throw "zero" : "zero" + } + "ok" : "ok" + } +} +function user.always_ok(x: int) -> string throws never { + { : "always ok" + "always ok" : "always ok" + } +} +function user.caller() -> string throws never { + { : string + may_fail(1) : string + } +} +function user.safe_caller() -> string throws never { + { : string | "caught" + catch (may_fail(1) : string) : unknown + catch (e) + _ => + "caught" : "caught" + } +} +type user.ThrowingCallback = () -> int throws string +type user.PureCallback = () -> int +type user.ExplicitPure = () -> int +type user.Mapper = (int) -> string throws string +type user.Wrapper = (() -> int) -> int throws string +type user.ThrowingCallback$stream = unknown +type user.PureCallback$stream = unknown +type user.ExplicitPure$stream = unknown +type user.Mapper$stream = unknown +type user.Wrapper$stream = unknown +function user.apply(f: () -> int) -> int throws never { + { : int + f() : int + } +} +function user.test_apply_pure() -> int throws never { + { : int + apply(() -> int { ... }) : int + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.test_apply_pure { +} +function user.test_apply_throwing() -> int throws never { + { : int + apply(() -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "boom" + } + } +} +lambda user.test_apply_throwing { +} +function user.apply_throwing(f: () -> int throws string) -> int throws never { + { : int + f() : int + } +} +function user.test_apply_explicit_throws() -> int throws never { + { : int + apply_throwing(() -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "boom" + } + } +} +lambda user.test_apply_explicit_throws { +} +function user.apply_and_throw(f: () -> int) -> int throws never { + { : int + let result = f() : int + if (result < 0 : bool) : void + { : never + throw "negative result" : "negative result" + } + result : int + } +} +function user.test_apply_and_throw_pure() -> int throws never { + { : int + apply_and_throw(() -> int { ... }) : int + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.test_apply_and_throw_pure { +} +function user.test_explicit_throws_match() -> int throws never { + { : never + let f = : () -> never throws string + () -> int throws string { ... } : () -> never throws string + { + throw "boom" + } + f() : never + } +} +lambda user.test_explicit_throws_match { +} +function user.test_explicit_throws_never_pure() -> int throws never { + { : 42 + let f = : () -> 42 + () -> int throws never { ... } : () -> 42 + { + 42 + } + f() : 42 + } +} +lambda user.test_explicit_throws_never_pure { +} +function user.test_explicit_throws_wider_than_body() -> int throws never { + { : 42 + let f = : () -> 42 throws string + () -> int throws string { ... } : () -> 42 throws string + { + 42 + } + f() : 42 + } +} +lambda user.test_explicit_throws_wider_than_body { + ?? 547..554: extraneous throws declaration: string +} +function user.test_pure_lambda() -> int throws never { + { : 42 + let f = : () -> 42 + () -> int { ... } : () -> 42 + { + 42 + } + f() : 42 + } +} +lambda user.test_pure_lambda { +} +function user.test_throwing_lambda() -> int throws never { + { : never + let f = : () -> never throws string + () -> int { ... } : () -> never throws string + { + throw "boom" + } + f() : never + } +} +lambda user.test_throwing_lambda { +} +function user.test_throwing_int() -> string throws never { + { : never + let f = : () -> never throws int + () -> string { ... } : () -> never throws int + { + throw 42 + } + f() : never + } +} +lambda user.test_throwing_int { +} +function user.test_conditional_throw(x: int) -> int throws never { + { : int + let f = : (n: int) -> int throws string + (n: int) -> int { ... } : (n: int) -> int throws string + { + if (n < 0) + { + throw "negative" + } + n + } + f(x) : int + } +} +lambda user.test_conditional_throw { +} +function user.test_multi_throw_types(x: int) -> string throws never { + { : "ok" + let f = : (n: int) -> "ok" throws int | string + (n: int) -> string { ... } : (n: int) -> "ok" throws int | string + { + match (n) { 0 => throw "string error", 1 => throw 1, _ => "ok" } + } + f(x) : "ok" + } +} +lambda user.test_multi_throw_types { +} +function user.test_throws_never_but_throws() -> int throws never { + { : never + let f = : () -> never + () -> int throws never { ... } : () -> never + { + throw "boom" + } + f() : never + } +} +lambda user.test_throws_never_but_throws { + !! 213..219: throws contract violation: `never` is missing string +} +function user.test_nested_inner_throws() -> int throws never { + { : never + let outer = : () -> never throws string + () -> int { ... } : () -> never throws string + { + let inner = ... + () -> int { ... } + { + throw "inner boom" + } + inner() + } + outer() : never + } +} +lambda user.test_nested_inner_throws { +} +lambda user.test_nested_inner_throws { +} +function user.test_nested_outer_throws() -> int throws never { + { : never + let outer = : () -> never throws string + () -> int { ... } : () -> never throws string + { + let inner = ... + () -> int { ... } + { + 42 + } + throw "outer boom" + } + outer() : never + } +} +lambda user.test_nested_outer_throws { +} +lambda user.test_nested_outer_throws { +} +function user.test_nested_both_throw() -> int throws never { + { : never + let outer = : () -> never throws int | unknown + () -> int { ... } : () -> never throws int | unknown + { + let inner = ... + () -> int { ... } + { + throw 42 + } + inner() + throw "outer" + } + outer() : never + } + ?? 0..0: unreachable code: 1 statement(s) after diverging statement +} +lambda user.test_nested_both_throw { + ?? 568..575: unreachable code: 1 statement(s) after diverging statement +} +lambda user.test_nested_both_throw { +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index 40cf564409..589b199aab 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -2,4 +2,42 @@ source: crates/baml_tests/src/generated_tests.rs --- === COMPILER2 DIAGNOSTICS === -No errors found. + [type] Warning: extraneous throws declaration: string + ╭─[ lambda_throws_explicit.baml:17:27 ] + │ + 17 │ let f = () -> int throws string { 42 } + │ ───┬─── + │ ╰───── extraneous throws declaration: string + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `never` is missing string + ╭─[ lambda_throws_violation.baml:5:27 ] + │ + 5 │ let f = () -> int throws never { throw "boom" } + │ ───┬── + │ ╰──── throws contract violation: `never` is missing string + │ + │ Note: Error code: E0001 +───╯ + + [type] Warning: unreachable code: 1 statement(s) after diverging statement + ╭─[ nested_lambda_throws.baml:1:1 ] + │ + 1 │ // === Nested lambda throws behavior === + │ │ + │ ╰─ unreachable code: 1 statement(s) after diverging statement + │ + │ Note: Error code: E0001 +───╯ + + [type] Warning: unreachable code: 1 statement(s) after diverging statement + ╭─[ nested_lambda_throws.baml:25:5 ] + │ + 25 │ inner() + │ ───┬─── + │ ╰───── unreachable code: 1 statement(s) after diverging statement + │ + │ Note: Error code: E0001 +────╯ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 3af59d72a8..01c949daa0 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -1,8 +1,192 @@ --- source: crates/baml_tests/src/generated_tests.rs --- +function user.always_ok(x: int) -> string { + load_const "always ok" + return +} + +function user.apply(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.apply_and_throw(f: () -> int) -> int { + load_var f + call_indirect + store_var result + load_var result + load_const 0 + cmp_op < + pop_jump_if_false L0 + jump L1 + + L0: + load_var result + return + + L1: + load_const "negative result" + throw +} + +function user.apply_throwing(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.caller() -> string { + load_const 1 + call user.may_fail + return +} + +function user.make_pure_handler() -> PureHandler { + alloc_instance PureHandler + copy 0 + make_closure ., 0 + store_field .run + return +} + +function user.make_throwing_handler() -> ThrowingHandler { + alloc_instance ThrowingHandler + copy 0 + make_closure ., 0 + store_field .run + return +} + +function user.may_fail(x: int) -> string { + load_var x + load_const 0 + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_const "ok" + return + + L1: + load_const "zero" + throw +} + +function user.safe_caller() -> string { + load_const 1 + call user.may_fail + jump L0 + load_var e + throw_if_panic + load_const "caught" + + L0: + return +} + function user.takes_throwing(f: () -> int) -> int { load_var f call_indirect return } + +function user.test_apply_and_throw_pure() -> int { + make_closure ., 0 + call user.apply_and_throw + return +} + +function user.test_apply_explicit_throws() -> int { + make_closure ., 0 + call user.apply_throwing + return +} + +function user.test_apply_pure() -> int { + make_closure ., 0 + call user.apply + return +} + +function user.test_apply_throwing() -> int { + make_closure ., 0 + call user.apply + return +} + +function user.test_conditional_throw(x: int) -> int { + load_var x + make_closure ., 0 + call_indirect + return +} + +function user.test_explicit_throws_match() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_explicit_throws_never_pure() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_explicit_throws_wider_than_body() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_multi_throw_types(x: int) -> string { + load_var x + make_closure ., 0 + call_indirect + return +} + +function user.test_nested_both_throw() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_nested_inner_throws() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_nested_outer_throws() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_pure_lambda() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_throwing_int() -> string { + make_closure ., 0 + call_indirect + return +} + +function user.test_throwing_lambda() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_throws_never_but_throws() -> int { + make_closure ., 0 + call_indirect + return +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__class_field_fn_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__class_field_fn_throws.snap new file mode 100644 index 0000000000..9f932c7d0f --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__class_field_fn_throws.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Unexpected additional element at class_field_fn_throws.baml:8:18 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_decl_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_decl_throws.snap new file mode 100644 index 0000000000..dd5bc9a100 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_decl_throws.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Expected token/node of kind LLM_FUNCTION_BODY or EXPR_FUNCTION_BODY, but found THROWS_CLAUSE at fn_decl_throws.baml:4:36 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_type_alias_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_type_alias_throws.snap new file mode 100644 index 0000000000..1cf46de82e --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_type_alias_throws.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Unexpected additional element at fn_type_alias_throws.baml:4:34 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_throws.snap new file mode 100644 index 0000000000..901b211f5d --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_throws.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Unexpected additional element at hof_throws.baml:19:37 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_explicit.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_explicit.snap new file mode 100644 index 0000000000..c323917448 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_explicit.snap @@ -0,0 +1,28 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === Lambda explicit throws annotations === + +// Lambda with explicit throws string - body matches +function test_explicit_throws_match() -> int { + let f = () -> int throws string { + throw "boom" + } + f() +} + +// Lambda with explicit throws never - body is pure (should be fine) +function test_explicit_throws_never_pure() -> int { + let f = () -> int throws never { + 42 + } + f() +} + +// Lambda with explicit throws string - body is pure (should be fine, never <: string) +function test_explicit_throws_wider_than_body() -> int { + let f = () -> int throws string { + 42 + } + f() +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_infer.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_infer.snap new file mode 100644 index 0000000000..738d80bc60 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_infer.snap @@ -0,0 +1,55 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === Lambda throws inference === + +// Pure lambda - should infer throws: never +function test_pure_lambda() -> int { + let f = () -> int { + 42 + } + f() +} + +// Throwing lambda - should infer throws: string +function test_throwing_lambda() -> int { + let f = () -> int { + throw "boom" + } + f() +} + +// Lambda throwing int - should infer throws: int +function test_throwing_int() -> string { + let f = () -> string { + throw 42 + } + f() +} + +// Lambda with conditional throw - only one branch throws +function test_conditional_throw(x: int) -> int { + let f = (n: int) -> int { + if (n < 0) { + throw "negative" + } + n + } + f(x) +} + +// Lambda with multiple throw types - should infer throws: string | int +function test_multi_throw_types(x: int) -> string { + let f = (n: int) -> string { + match (n) { + 0 => { + throw "string error" + }, + 1 => { + throw 1 + }, + _ => "ok", + } + } + f(x) +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_violation.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_violation.snap new file mode 100644 index 0000000000..1b6857dba1 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_violation.snap @@ -0,0 +1,12 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === Lambda throws contract violations (should produce errors) === + +// Lambda with explicit throws never but body throws - should error +function test_throws_never_but_throws() -> int { + let f = () -> int throws never { + throw "boom" + } + f() +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__nested_lambda_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__nested_lambda_throws.snap new file mode 100644 index 0000000000..8502e2b03c --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__nested_lambda_throws.snap @@ -0,0 +1,38 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === Nested lambda throws behavior === + +// Inner lambda throws, outer doesn't throw directly +function test_nested_inner_throws() -> int { + let outer = () -> int { + let inner = () -> int { + throw "inner boom" + } + inner() + } + outer() +} + +// Outer lambda throws, inner is pure +function test_nested_outer_throws() -> int { + let outer = () -> int { + let inner = () -> int { + 42 + } + throw "outer boom" + } + outer() +} + +// Both throw different types +function test_nested_both_throw() -> int { + let outer = () -> int { + let inner = () -> int { + throw 42 + } + inner(); + throw "outer" + } + outer() +} diff --git a/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_tir.snap b/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_tir.snap index 51b49e83d2..76a19013f0 100644 --- a/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_tir.snap @@ -52,8 +52,8 @@ lambda user.test_map_inferred { } function user.test_throws() -> int throws never { { : int - let risky = : (x: int) -> int - (x: int) -> int throws string { ... } : (x: int) -> int + let risky = : (x: int) -> int throws string + (x: int) -> int throws string { ... } : (x: int) -> int throws string { if (x < 0) { From 1f64000d07e8dd3cf7dc27a21fbe07b5888d71ca Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 05:49:49 -0500 Subject: [PATCH 03/26] baml_language: finalize typed rethrows for higher-order functions --- .../crates/baml_compiler2_tir/src/builder.rs | 730 ++++----- .../src/effective_throws.rs | 741 +++++++++ .../baml_compiler2_tir/src/inference.rs | 63 +- .../crates/baml_compiler2_tir/src/lib.rs | 1 + .../baml_compiler2_tir/src/lower_type_expr.rs | 118 ++ .../array_map_throws.baml | 16 + .../catch_absorbs_throws.baml | 32 + .../function_type_throws/chained_hof.baml | 17 + .../explicit_param_throws.baml | 23 + .../function_type_throws/hof_own_throws.baml | 19 + .../function_type_throws/hof_rethrows.baml | 43 + .../function_type_throws/hof_throws.baml | 19 + .../function_type_throws/mixed_params.baml | 26 + .../returned_closures.baml | 22 + .../baml_tests____baml_std____04_tir.snap | 22 +- .../baml_tests____testing_std____04_tir.snap | 16 +- .../baml_tests__builtin_io__04_tir.snap | 22 +- ...baml_tests__catch_all_keyword__04_tir.snap | 4 +- .../baml_tests__catch_throw__04_tir.snap | 10 +- ...ests__catch_throw_regressions__04_tir.snap | 12 +- .../baml_tests__format_checks__04_tir.snap | 2 +- ...pe_throws__01_lexer__array_map_throws.snap | 114 ++ ...hrows__01_lexer__catch_absorbs_throws.snap | 166 ++ ...on_type_throws__01_lexer__chained_hof.snap | 102 ++ ...rows__01_lexer__explicit_param_throws.snap | 140 ++ ...type_throws__01_lexer__hof_own_throws.snap | 144 ++ ...n_type_throws__01_lexer__hof_rethrows.snap | 304 ++++ ...ion_type_throws__01_lexer__hof_throws.snap | 99 ++ ...n_type_throws__01_lexer__mixed_params.snap | 184 +++ ...e_throws__01_lexer__returned_closures.snap | 120 ++ ...e_throws__02_parser__array_map_throws.snap | 136 ++ ...rows__02_parser__catch_absorbs_throws.snap | 183 +++ ...n_type_throws__02_parser__chained_hof.snap | 127 ++ ...ows__02_parser__explicit_param_throws.snap | 163 ++ ...ype_throws__02_parser__hof_own_throws.snap | 121 ++ ..._type_throws__02_parser__hof_rethrows.snap | 426 +++++ ...on_type_throws__02_parser__hof_throws.snap | 114 ++ ..._type_throws__02_parser__mixed_params.snap | 238 +++ ..._throws__02_parser__returned_closures.snap | 136 ++ ...l_tests__function_type_throws__03_hir.snap | 129 ++ ...tests__function_type_throws__04_5_mir.snap | 1398 ++++++++++++++++- ...l_tests__function_type_throws__04_tir.snap | 428 ++++- ..._function_type_throws__05_diagnostics.snap | 72 + ...sts__function_type_throws__06_codegen.snap | 301 ++++ ...hrows__10_formatter__array_map_throws.snap | 29 + ...s__10_formatter__catch_absorbs_throws.snap | 5 + ...ype_throws__10_formatter__chained_hof.snap | 30 + ...__10_formatter__explicit_param_throws.snap | 5 + ..._throws__10_formatter__hof_own_throws.snap | 32 + ...pe_throws__10_formatter__hof_rethrows.snap | 89 ++ ...pe_throws__10_formatter__mixed_params.snap | 45 + ...rows__10_formatter__returned_closures.snap | 5 + .../baml_tests__lambda_advanced__04_tir.snap | 42 +- .../baml_tests__lambda_basic__04_tir.snap | 4 +- .../baml_tests__lambda_errors__04_tir.snap | 4 +- ...s__namespaces_type_resolution__04_tir.snap | 22 +- .../baml_tests__null_handling__04_tir.snap | 2 +- .../baml_tests__type_annotation__04_tir.snap | 6 +- .../baml_tests/src/compiler2_tir/mod.rs | 84 +- 59 files changed, 7137 insertions(+), 570 deletions(-) create mode 100644 baml_language/crates/baml_compiler2_tir/src/effective_throws.rs create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/array_map_throws.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/catch_absorbs_throws.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/chained_hof.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/explicit_param_throws.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/hof_own_throws.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/hof_rethrows.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/mixed_params.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/returned_closures.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__array_map_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__catch_absorbs_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__chained_hof.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__explicit_param_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_own_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_rethrows.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__mixed_params.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__returned_closures.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__array_map_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__catch_absorbs_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__chained_hof.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__explicit_param_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_own_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_rethrows.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__mixed_params.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__returned_closures.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__array_map_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__catch_absorbs_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__chained_hof.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__explicit_param_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_own_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_rethrows.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__mixed_params.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__returned_closures.snap diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index af81ae3f1f..a765fe139e 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -15,7 +15,7 @@ use std::collections::{BTreeSet, HashMap}; -use baml_base::Name; +use baml_base::{Name, SourceFile}; use baml_compiler2_ast::{Expr, ExprBody, ExprId, PatId, Stmt, StmtId, TypeExpr}; use baml_compiler2_hir::{ contributions::Definition, @@ -188,6 +188,7 @@ impl<'db> TypeInferenceBuilder<'db> { FxHashMap, FxHashMap>, FxHashSet, + FxHashMap>, TypeCheckDiagnostics<'db>, ) { let diagnostics = self.context.finish(); @@ -196,6 +197,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.bindings, self.resolutions, self.exhaustive_matches, + self.catch_residual_throws, diagnostics, ) } @@ -721,7 +723,6 @@ impl<'db> TypeInferenceBuilder<'db> { return_annotation.as_ref(), expr_id, ); - // Determine throws: explicit annotation takes precedence, // otherwise infer from the body. let throws_ty = if let Some(te) = &func_def.throws { @@ -956,7 +957,12 @@ impl<'db> TypeInferenceBuilder<'db> { }; match &callee_ty { - Ty::Function { params, ret, .. } => { + Ty::Function { + params, + ret, + throws: callee_throws, + .. + } => { let effective_params = if is_method_call { crate::generics::skip_self_param(params) } else { @@ -1049,6 +1055,17 @@ impl<'db> TypeInferenceBuilder<'db> { self.context.report_simple(d, expr_id); } + // Phase 2b: if we have bindings and the callee has effect vars + // in its throws, substitute them and update the callee's recorded + // type so that collect_effective_throws_from_expr sees the + // resolved throws rather than unbound TypeVars. + if !bindings.is_empty() && crate::generics::contains_typevar(callee_throws) + { + let substituted_callee_ty = + crate::generics::substitute_ty(&callee_ty, &bindings); + self.record_expr_type(*callee, substituted_callee_ty); + } + // Subtype check against expected type (skip if we did generic // inference — the inference already accounts for expected) if bindings.is_empty() @@ -1192,7 +1209,6 @@ impl<'db> TypeInferenceBuilder<'db> { Some(effective_ret), expr_id, ); - // Determine throws: explicit annotation > expected > inferred from body let throws_ty = if let Some(te) = &func_def.throws { let mut diags = Vec::new(); @@ -2509,204 +2525,15 @@ impl<'db> TypeInferenceBuilder<'db> { } fn collect_effective_throws(&self, body: &ExprBody) -> BTreeSet { - let mut out = BTreeSet::new(); - if let Some(root) = body.root_expr { - self.collect_effective_throws_from_expr(root, body, &mut out); - } - out - } - - fn collect_effective_throws_from_expr( - &self, - expr_id: ExprId, - body: &ExprBody, - out: &mut BTreeSet, - ) { - match &body.exprs[expr_id] { - Expr::Throw { value } => { - self.collect_effective_throws_from_expr(*value, body, out); - self.collect_throw_facts_from_value(*value, out); - } - Expr::Call { callee, args } => { - self.collect_effective_throws_from_expr(*callee, body, out); - for arg in args { - self.collect_effective_throws_from_expr(*arg, body, out); - } - if let Some(target) = self.call_target_name(*callee, body) { - let throws = crate::throw_inference::function_throw_sets( - self.context.db(), - self.package_id, - ); - if let Some(transitive) = throws.transitive_for(&target) { - out.extend(transitive.iter().cloned()); - } else { - // Target not in throw set registry (function parameter, - // external function, etc.). Check if the callee has a - // type-level throws annotation before falling back to Unknown. - if let Some(Ty::Function { throws, .. }) = self.expressions.get(callee) { - let facts = crate::throw_inference::flatten_ty_to_facts(throws); - out.extend(facts); - } else { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); - } - } - } else { - // No resolvable target name — check type-level throws. - if let Some(Ty::Function { throws, .. }) = self.expressions.get(callee) { - let facts = crate::throw_inference::flatten_ty_to_facts(throws); - out.extend(facts); - } else { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); - } - } - } - Expr::Catch { clauses, .. } => { - if let Some(residual) = self.catch_residual_throws.get(&expr_id) { - out.extend(residual.iter().cloned()); - } - for clause in clauses { - for arm_id in &clause.arms { - let arm = &body.catch_arms[*arm_id]; - self.collect_effective_throws_from_expr(arm.body, body, out); - } - } - } - Expr::If { - condition, - then_branch, - else_branch, - } => { - self.collect_effective_throws_from_expr(*condition, body, out); - self.collect_effective_throws_from_expr(*then_branch, body, out); - if let Some(else_expr) = else_branch { - self.collect_effective_throws_from_expr(*else_expr, body, out); - } - } - Expr::Match { - scrutinee, arms, .. - } => { - self.collect_effective_throws_from_expr(*scrutinee, body, out); - for arm_id in arms { - let arm = &body.match_arms[*arm_id]; - if let Some(guard) = arm.guard { - self.collect_effective_throws_from_expr(guard, body, out); - } - self.collect_effective_throws_from_expr(arm.body, body, out); - } - } - Expr::Binary { lhs, rhs, .. } => { - self.collect_effective_throws_from_expr(*lhs, body, out); - self.collect_effective_throws_from_expr(*rhs, body, out); - } - Expr::Unary { expr, .. } => { - self.collect_effective_throws_from_expr(*expr, body, out); - } - Expr::Object { - fields, spreads, .. - } => { - for (_, value) in fields { - self.collect_effective_throws_from_expr(*value, body, out); - } - for spread in spreads { - self.collect_effective_throws_from_expr(spread.expr, body, out); - } - } - Expr::Array { elements } => { - for elem in elements { - self.collect_effective_throws_from_expr(*elem, body, out); - } - } - Expr::Map { entries } => { - for (key, value) in entries { - self.collect_effective_throws_from_expr(*key, body, out); - self.collect_effective_throws_from_expr(*value, body, out); - } - } - Expr::Block { stmts, tail_expr } => { - for stmt_id in stmts { - self.collect_effective_throws_from_stmt(*stmt_id, body, out); - } - if let Some(tail) = tail_expr { - self.collect_effective_throws_from_expr(*tail, body, out); - } - } - Expr::FieldAccess { base, .. } | Expr::OptionalFieldAccess { base, .. } => { - self.collect_effective_throws_from_expr(*base, body, out); - } - Expr::Index { base, index } | Expr::OptionalIndex { base, index } => { - self.collect_effective_throws_from_expr(*base, body, out); - self.collect_effective_throws_from_expr(*index, body, out); - } - Expr::OptionalCall { callee, args } => { - self.collect_effective_throws_from_expr(*callee, body, out); - for arg in args { - self.collect_effective_throws_from_expr(*arg, body, out); - } - } - Expr::OptionalChain { expr } => { - self.collect_effective_throws_from_expr(*expr, body, out); - } - Expr::Lambda(_) - | Expr::Literal(_) - | Expr::ByteStringLiteral(_) - | Expr::Null - | Expr::Path(_) - | Expr::Missing => {} - } - } - - fn collect_effective_throws_from_stmt( - &self, - stmt_id: StmtId, - body: &ExprBody, - out: &mut BTreeSet, - ) { - match &body.stmts[stmt_id] { - Stmt::Expr(expr) => self.collect_effective_throws_from_expr(*expr, body, out), - Stmt::Let { initializer, .. } => { - if let Some(init) = initializer { - self.collect_effective_throws_from_expr(*init, body, out); - } - } - Stmt::While { - condition, - body: while_body, - after, - .. - } => { - self.collect_effective_throws_from_expr(*condition, body, out); - self.collect_effective_throws_from_expr(*while_body, body, out); - if let Some(after_stmt) = after { - self.collect_effective_throws_from_stmt(*after_stmt, body, out); - } - } - Stmt::For { - collection, - body: for_body, - .. - } => { - self.collect_effective_throws_from_expr(*collection, body, out); - self.collect_effective_throws_from_expr(*for_body, body, out); - } - Stmt::Return(expr) => { - if let Some(expr) = expr { - self.collect_effective_throws_from_expr(*expr, body, out); - } - } - Stmt::Assign { target, value } | Stmt::AssignOp { target, value, .. } => { - self.collect_effective_throws_from_expr(*target, body, out); - self.collect_effective_throws_from_expr(*value, body, out); - } - Stmt::Throw { value } => { - self.collect_effective_throws_from_expr(*value, body, out); - self.collect_throw_facts_from_value(*value, out); - } - Stmt::Break | Stmt::Continue | Stmt::Missing | Stmt::HeaderComment { .. } => {} - } + crate::effective_throws::collect_effective_throws( + self.context.db(), + self.package_id, + body, + &self.expressions, + &self.catch_residual_throws, + true, + true, + ) } fn collect_throw_facts_from_expr( @@ -2718,14 +2545,23 @@ impl<'db> TypeInferenceBuilder<'db> { match &body.exprs[expr_id] { Expr::Throw { value } => { self.collect_throw_facts_from_expr(*value, body, out); - self.collect_throw_facts_from_value(*value, out); + self.collect_throw_facts_from_value(*value, body, out); } Expr::Call { callee, args } => { self.collect_throw_facts_from_expr(*callee, body, out); for arg in args { self.collect_throw_facts_from_expr(*arg, body, out); } - if let Some(target) = self.call_target_name(*callee, body) { + let type_level_facts = self.expressions.get(callee).and_then(|ty| match ty { + Ty::Function { throws, .. } => { + let facts = crate::throw_inference::flatten_ty_to_facts(throws); + if facts.is_empty() { None } else { Some(facts) } + } + _ => None, + }); + if let Some(facts) = type_level_facts { + out.extend(facts); + } else if let Some(target) = self.call_target_name(*callee, body) { let throws = crate::throw_inference::function_throw_sets( self.context.db(), self.package_id, @@ -2872,18 +2708,39 @@ impl<'db> TypeInferenceBuilder<'db> { } Stmt::Throw { value } => { self.collect_throw_facts_from_expr(*value, body, out); - self.collect_throw_facts_from_value(*value, out); + self.collect_throw_facts_from_value(*value, body, out); } Stmt::Break | Stmt::Continue | Stmt::Missing | Stmt::HeaderComment { .. } => {} } } - fn collect_throw_facts_from_value(&self, value_expr_id: ExprId, out: &mut BTreeSet) { - let unknown_ty = Ty::Unknown { - attr: TyAttr::default(), - }; - let thrown_ty = self.expressions.get(&value_expr_id).unwrap_or(&unknown_ty); - out.extend(crate::throw_inference::flatten_ty_to_facts(thrown_ty)); + fn collect_throw_facts_from_value( + &self, + value_expr_id: ExprId, + body: &ExprBody, + out: &mut BTreeSet, + ) { + if let Some(thrown_ty) = self.expressions.get(&value_expr_id) { + out.extend(crate::throw_inference::flatten_ty_to_facts(thrown_ty)); + return; + } + + match &body.exprs[value_expr_id] { + Expr::Literal(lit) => out.extend(crate::throw_inference::flatten_ty_to_facts( + &Ty::Literal(lit.clone(), Freshness::Regular, TyAttr::default()), + )), + Expr::ByteStringLiteral(_) => { + out.insert(Ty::Primitive(PrimitiveType::Uint8Array, TyAttr::default())); + } + Expr::Null => { + out.insert(Ty::Primitive(PrimitiveType::Null, TyAttr::default())); + } + _ => { + out.insert(Ty::Unknown { + attr: TyAttr::default(), + }); + } + } } fn call_target_name(&self, callee_expr_id: ExprId, body: &ExprBody) -> Option { @@ -2921,6 +2778,185 @@ impl<'db> TypeInferenceBuilder<'db> { } } + fn find_function_scope_id( + &self, + file: SourceFile, + span: TextRange, + name: &Name, + ) -> ScopeId<'db> { + let db = self.context.db(); + let index = baml_compiler2_hir::file_semantic_index(db, file); + let file_scope_id = index + .scopes + .iter() + .enumerate() + .find_map(|(i, scope)| { + if scope.kind == baml_compiler2_hir::scope::ScopeKind::Function + && scope.range == span + && scope.name.as_ref() == Some(name) + { + #[allow(clippy::cast_possible_truncation)] + Some(baml_compiler2_hir::scope::FileScopeId::new(i as u32)) + } else { + None + } + }) + .unwrap_or_else(|| index.scope_at_offset(span.start(), Some(name))); + file_scope_id.to_scope_id(db, file) + } + + fn infer_concrete_body_throws( + &self, + body_scope: ScopeId<'db>, + body_package_id: PackageId<'db>, + body: &baml_compiler2_hir::body::FunctionBody, + ) -> BTreeSet { + let db = self.context.db(); + if let baml_compiler2_hir::body::FunctionBody::Expr(ref expr_body) = *body { + let scope_inference = crate::inference::infer_scope_types(db, body_scope); + crate::throw_inference::flatten_ty_to_facts(&scope_inference.effective_throws( + db, + body_package_id, + expr_body, + )) + .into_iter() + .filter(|fact| !matches!(fact, Ty::TypeVar(_, _))) + .collect() + } else { + BTreeSet::new() + } + } + + fn combine_effect_vars_with_body_throws( + synthetic_effect_vars: &[Name], + body_throws_facts: BTreeSet, + ) -> Ty { + let mut all_throws: Vec = synthetic_effect_vars + .iter() + .map(|v| Ty::TypeVar(v.clone(), TyAttr::default())) + .collect(); + all_throws.extend(body_throws_facts); + all_throws.retain(|t| !matches!(t, Ty::Never { .. } | Ty::Void { .. })); + + match all_throws.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => all_throws.remove(0), + _ => Ty::Union(all_throws, TyAttr::default()), + } + } + + #[allow(clippy::too_many_arguments)] + fn build_function_ty_from_signature( + &self, + pkg_items: &PackageItems<'db>, + ns_context: &[Name], + generic_params: &[Name], + sig: &baml_compiler2_hir::signature::FunctionSignature, + body_scope: Option>, + body_package_id: Option>, + body: Option<&baml_compiler2_hir::body::FunctionBody>, + self_param_ty: Option<&Ty>, + ) -> Ty { + let db = self.context.db(); + let mut diags = Vec::new(); + let mut synthetic_effect_vars: Vec = Vec::new(); + + let params: Vec<(Option, Ty)> = sig + .params + .iter() + .map(|(n, te)| { + let param_ty = if n.as_str() == "self" + && matches!(te, baml_compiler2_ast::TypeExpr::Unknown { .. }) + { + self_param_ty.cloned().unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }) + } else { + crate::lower_type_expr::lower_type_expr_with_fn_context( + db, + te, + pkg_items, + ns_context, + generic_params, + &mut diags, + &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { + param_name: n.clone(), + }, + &mut synthetic_effect_vars, + ) + }; + (Some(n.clone()), param_ty) + }) + .collect(); + + let effective_generic_params: Vec = generic_params + .iter() + .cloned() + .chain(synthetic_effect_vars.iter().cloned()) + .collect(); + + let ret_ty = sig + .return_type + .as_ref() + .map(|te| { + crate::lower_type_expr::lower_type_expr_in_ns( + db, + te, + pkg_items, + ns_context, + &effective_generic_params, + &mut diags, + ) + }) + .unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }); + + let throws_ty = sig + .throws + .as_ref() + .map(|te| { + crate::lower_type_expr::lower_type_expr_in_ns( + db, + te, + pkg_items, + ns_context, + generic_params, + &mut diags, + ) + }) + .unwrap_or_else(|| { + let has_callback_param = params + .iter() + .any(|(_, ty)| matches!(ty, Ty::Function { .. })); + if !has_callback_param { + return Ty::Never { + attr: TyAttr::default(), + }; + } + + let body_throws_facts = match (body_scope, body_package_id, body) { + (Some(body_scope), Some(body_package_id), Some(body)) => { + self.infer_concrete_body_throws(body_scope, body_package_id, body) + } + _ => BTreeSet::new(), + }; + Self::combine_effect_vars_with_body_throws( + &synthetic_effect_vars, + body_throws_facts, + ) + }); + + Ty::Function { + params, + ret: Box::new(ret_ty), + throws: Box::new(throws_ty), + attr: TyAttr::default(), + } + } + fn expr_to_path_segments(expr_id: ExprId, body: &ExprBody) -> Option> { match &body.exprs[expr_id] { Expr::Path(segments) if !segments.is_empty() => Some(segments.clone()), @@ -3061,61 +3097,22 @@ impl<'db> TypeInferenceBuilder<'db> { crate::inference::MemberResolution::Free { func_loc }, ); let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); - let mut diags = Vec::new(); - let ty = Ty::Function { - params: sig - .params - .iter() - .map(|(n, te)| { - ( - Some(n.clone()), - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items, - &ns_context, - generic_params, - &mut diags, - ), - ) - }) - .collect(), - ret: Box::new( - sig.return_type - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items, - &ns_context, - generic_params, - &mut diags, - ) - }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }), - ), - throws: Box::new( - sig.throws - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items, - &ns_context, - generic_params, - &mut diags, - ) - }) - .unwrap_or(Ty::Never { - attr: TyAttr::default(), - }), - ), - attr: TyAttr::default(), - }; + let func_scope = self.find_function_scope_id( + func_loc.file(db), + func_data_for_sig.span, + &func_data_for_sig.name, + ); + let func_body = baml_compiler2_hir::body::function_body(db, func_loc); + let ty = self.build_function_ty_from_signature( + pkg_items, + &ns_context, + generic_params, + sig.as_ref(), + Some(func_scope), + Some(PackageId::new(db, pkg_info.package)), + Some(func_body.as_ref()), + None, + ); return Some(ty); } @@ -3160,67 +3157,25 @@ impl<'db> TypeInferenceBuilder<'db> { let item_tree = baml_compiler2_ppir::file_item_tree(db, func_loc.file(db)); let func_data = &item_tree[func_loc.id(db)]; let generic_params = &func_data.generic_params; - let sig_ns = - baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)) - .namespace_path; - let mut diags = Vec::new(); - - // Note: diags from referenced function signatures are not - // reported here — they'll be reported at the definition site. - Ty::Function { - params: sig - .params - .iter() - .map(|(n, te)| { - ( - Some(n.clone()), - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - self.package_items, - &sig_ns, - generic_params, - &mut diags, - ), - ) - }) - .collect(), - ret: Box::new( - sig.return_type - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - self.package_items, - &sig_ns, - generic_params, - &mut diags, - ) - }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }), - ), - throws: Box::new( - sig.throws - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - self.package_items, - &sig_ns, - generic_params, - &mut diags, - ) - }) - .unwrap_or(Ty::Never { - attr: TyAttr::default(), - }), - ), - attr: TyAttr::default(), - } + let sig_pkg = + baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)); + let sig_ns = sig_pkg.namespace_path; + let func_scope = self.find_function_scope_id( + func_loc.file(db), + func_data.span, + &func_data.name, + ); + let func_body = baml_compiler2_hir::body::function_body(db, func_loc); + self.build_function_ty_from_signature( + self.package_items, + &sig_ns, + generic_params, + sig.as_ref(), + Some(func_scope), + Some(PackageId::new(db, sig_pkg.package)), + Some(func_body.as_ref()), + None, + ) } _ => Ty::Unknown { attr: TyAttr::default(), @@ -3682,67 +3637,23 @@ impl<'db> TypeInferenceBuilder<'db> { all_generic_params.extend(method_data.generic_params.iter().cloned()); let func_loc = baml_compiler2_hir::loc::FunctionLoc::new(db, file, method_id); let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); - let mut diags = Vec::new(); let class_ty = Ty::Class(class_name.clone(), TyAttr::default()); - let ty = Ty::Function { - params: sig - .params - .iter() - .map(|(n, te)| { - let param_ty = if n.as_str() == "self" - && matches!(te, baml_compiler2_ast::TypeExpr::Unknown { .. }) - { - // self with no annotation → use the enclosing class type - class_ty.clone() - } else { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items_for_class, - &ns_context, - &all_generic_params, - &mut diags, - ) - }; - (Some(n.clone()), param_ty) - }) - .collect(), - ret: Box::new( - sig.return_type - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items_for_class, - &ns_context, - &all_generic_params, - &mut diags, - ) - }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }), - ), - throws: Box::new( - sig.throws - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items_for_class, - &ns_context, - &all_generic_params, - &mut diags, - ) - }) - .unwrap_or(Ty::Never { - attr: TyAttr::default(), - }), - ), - attr: TyAttr::default(), - }; + let method_scope = + self.find_function_scope_id(file, method_data.span, &method_data.name); + let method_body = baml_compiler2_hir::body::function_body(db, func_loc); + let ty = self.build_function_ty_from_signature( + pkg_items_for_class, + &ns_context, + &all_generic_params, + sig.as_ref(), + Some(method_scope), + Some(PackageId::new( + db, + baml_compiler2_hir::file_package::file_package(db, file).package, + )), + Some(method_body.as_ref()), + Some(&class_ty), + ); // Note: diags from method signatures are reported at definition site. return Some((ty, class_loc, func_loc)); } @@ -4104,6 +4015,15 @@ impl<'db> TypeInferenceBuilder<'db> { ) }; + // Collect generic params for this method (class + method generics). + let method_generic_params: Vec = class_data + .generic_params + .iter() + .chain(method_data.generic_params.iter()) + .cloned() + .collect(); + + let mut synthetic_effect_vars: Vec = Vec::new(); let params: Vec<(Option, Ty)> = sig .params .iter() @@ -4112,6 +4032,22 @@ impl<'db> TypeInferenceBuilder<'db> { && matches!(te, baml_compiler2_ast::TypeExpr::Unknown { .. }) { builtin_class_ty.clone() + } else if matches!(te, baml_compiler2_ast::TypeExpr::Function { .. }) { + // Function-typed parameter: use DirectParamRoot to create effect var. + let lowered = crate::lower_type_expr::lower_type_expr_with_fn_context( + db, + te, + baml_items, + stub_ns, + &method_generic_params, + &mut diags, + &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { + param_name: n.clone(), + }, + &mut synthetic_effect_vars, + ); + // Apply generic bindings to substitute type args. + crate::generics::substitute_ty(&lowered, &bindings) } else { crate::generics::lower_type_expr_with_generics( db, @@ -4141,6 +4077,22 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_or(Ty::Void { attr: TyAttr::default(), }); + + // Compute throws: if we have synthetic effect vars, use them. + let throws_ty = match synthetic_effect_vars.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => Ty::TypeVar(synthetic_effect_vars[0].clone(), TyAttr::default()), + _ => Ty::Union( + synthetic_effect_vars + .iter() + .map(|v| Ty::TypeVar(v.clone(), TyAttr::default())) + .collect(), + TyAttr::default(), + ), + }; + // Discard diags — they will be reported at the definition site // (the builtin .baml stub). We don't want to spam user code // with unresolved-type errors from builtin signatures. @@ -4149,9 +4101,7 @@ impl<'db> TypeInferenceBuilder<'db> { ty: Ty::Function { params, ret: Box::new(ret), - throws: Box::new(Ty::Never { - attr: TyAttr::default(), - }), + throws: Box::new(throws_ty), attr: TyAttr::default(), }, class_loc, diff --git a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs new file mode 100644 index 0000000000..2873e5e9d4 --- /dev/null +++ b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs @@ -0,0 +1,741 @@ +use std::collections::BTreeSet; + +use baml_base::Name; +use baml_compiler2_ast::{Expr, ExprBody, ExprId, Stmt, StmtId}; +use baml_compiler2_hir::package::PackageId; +use rustc_hash::FxHashMap; + +use crate::{ + throw_inference::{flatten_ty_to_facts, function_throw_sets}, + ty::{Freshness, QualifiedTypeName, Ty, TyAttr}, +}; + +pub(crate) fn collect_effective_throws<'db>( + db: &'db dyn crate::Db, + package_id: PackageId<'db>, + body: &ExprBody, + expressions: &FxHashMap, + catch_residual_throws: &FxHashMap>, + include_typevars: bool, + unknown_on_unresolved_call: bool, +) -> BTreeSet { + let mut out = BTreeSet::new(); + if let Some(root) = body.root_expr { + collect_effective_throws_from_expr( + db, + package_id, + root, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + &mut out, + ); + } + out +} + +#[derive(Clone, Copy)] +struct CallResolutionOptions { + include_typevars: bool, + unknown_on_unresolved_call: bool, +} + +#[allow(clippy::too_many_arguments)] +fn collect_effective_throws_from_expr<'db>( + db: &'db dyn crate::Db, + package_id: PackageId<'db>, + expr_id: ExprId, + body: &ExprBody, + expressions: &FxHashMap, + catch_residual_throws: &FxHashMap>, + include_typevars: bool, + unknown_on_unresolved_call: bool, + out: &mut BTreeSet, +) { + match &body.exprs[expr_id] { + Expr::Throw { value } => { + collect_effective_throws_from_expr( + db, + package_id, + *value, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_throw_facts_from_value(*value, body, expressions, out); + } + Expr::Call { callee, args } => { + collect_effective_throws_from_expr( + db, + package_id, + *callee, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + for arg in args { + collect_effective_throws_from_expr( + db, + package_id, + *arg, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + collect_effective_throws_from_call( + db, + package_id, + *callee, + body, + expressions, + CallResolutionOptions { + include_typevars, + unknown_on_unresolved_call, + }, + out, + ); + } + Expr::Catch { clauses, .. } => { + if let Some(residual) = catch_residual_throws.get(&expr_id) { + out.extend(residual.iter().cloned()); + } + for clause in clauses { + for arm_id in &clause.arms { + let arm = &body.catch_arms[*arm_id]; + collect_effective_throws_from_expr( + db, + package_id, + arm.body, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + } + Expr::If { + condition, + then_branch, + else_branch, + } => { + collect_effective_throws_from_expr( + db, + package_id, + *condition, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_effective_throws_from_expr( + db, + package_id, + *then_branch, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + if let Some(else_expr) = else_branch { + collect_effective_throws_from_expr( + db, + package_id, + *else_expr, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Expr::Match { + scrutinee, arms, .. + } => { + collect_effective_throws_from_expr( + db, + package_id, + *scrutinee, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + for arm_id in arms { + let arm = &body.match_arms[*arm_id]; + if let Some(guard) = arm.guard { + collect_effective_throws_from_expr( + db, + package_id, + guard, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + collect_effective_throws_from_expr( + db, + package_id, + arm.body, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Expr::Binary { lhs, rhs, .. } => { + collect_effective_throws_from_expr( + db, + package_id, + *lhs, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_effective_throws_from_expr( + db, + package_id, + *rhs, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + Expr::Unary { expr, .. } => { + collect_effective_throws_from_expr( + db, + package_id, + *expr, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + Expr::Object { + fields, spreads, .. + } => { + for (_, value) in fields { + collect_effective_throws_from_expr( + db, + package_id, + *value, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + for spread in spreads { + collect_effective_throws_from_expr( + db, + package_id, + spread.expr, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Expr::Array { elements } => { + for elem in elements { + collect_effective_throws_from_expr( + db, + package_id, + *elem, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Expr::Map { entries } => { + for (key, value) in entries { + collect_effective_throws_from_expr( + db, + package_id, + *key, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_effective_throws_from_expr( + db, + package_id, + *value, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Expr::Block { stmts, tail_expr } => { + for stmt_id in stmts { + collect_effective_throws_from_stmt( + db, + package_id, + *stmt_id, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + if let Some(tail) = tail_expr { + collect_effective_throws_from_expr( + db, + package_id, + *tail, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Expr::FieldAccess { base, .. } | Expr::OptionalFieldAccess { base, .. } => { + collect_effective_throws_from_expr( + db, + package_id, + *base, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + Expr::Index { base, index } | Expr::OptionalIndex { base, index } => { + collect_effective_throws_from_expr( + db, + package_id, + *base, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_effective_throws_from_expr( + db, + package_id, + *index, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + Expr::OptionalCall { callee, args } => { + collect_effective_throws_from_expr( + db, + package_id, + *callee, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + for arg in args { + collect_effective_throws_from_expr( + db, + package_id, + *arg, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Expr::OptionalChain { expr } => { + collect_effective_throws_from_expr( + db, + package_id, + *expr, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + Expr::Lambda(_) + | Expr::Literal(_) + | Expr::ByteStringLiteral(_) + | Expr::Null + | Expr::Path(_) + | Expr::Missing => {} + } +} + +#[allow(clippy::too_many_arguments)] +fn collect_effective_throws_from_stmt<'db>( + db: &'db dyn crate::Db, + package_id: PackageId<'db>, + stmt_id: StmtId, + body: &ExprBody, + expressions: &FxHashMap, + catch_residual_throws: &FxHashMap>, + include_typevars: bool, + unknown_on_unresolved_call: bool, + out: &mut BTreeSet, +) { + match &body.stmts[stmt_id] { + Stmt::Expr(expr) => collect_effective_throws_from_expr( + db, + package_id, + *expr, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ), + Stmt::Let { initializer, .. } => { + if let Some(init) = initializer { + collect_effective_throws_from_expr( + db, + package_id, + *init, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Stmt::While { + condition, + body: while_body, + after, + .. + } => { + collect_effective_throws_from_expr( + db, + package_id, + *condition, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_effective_throws_from_expr( + db, + package_id, + *while_body, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + if let Some(after_stmt) = after { + collect_effective_throws_from_stmt( + db, + package_id, + *after_stmt, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Stmt::For { + collection, + body: for_body, + .. + } => { + collect_effective_throws_from_expr( + db, + package_id, + *collection, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_effective_throws_from_expr( + db, + package_id, + *for_body, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + Stmt::Return(expr) => { + if let Some(expr) = expr { + collect_effective_throws_from_expr( + db, + package_id, + *expr, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + } + Stmt::Assign { target, value } | Stmt::AssignOp { target, value, .. } => { + collect_effective_throws_from_expr( + db, + package_id, + *target, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_effective_throws_from_expr( + db, + package_id, + *value, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + } + Stmt::Throw { value } => { + collect_effective_throws_from_expr( + db, + package_id, + *value, + body, + expressions, + catch_residual_throws, + include_typevars, + unknown_on_unresolved_call, + out, + ); + collect_throw_facts_from_value(*value, body, expressions, out); + } + Stmt::Break | Stmt::Continue | Stmt::Missing | Stmt::HeaderComment { .. } => {} + } +} + +fn collect_effective_throws_from_call<'db>( + db: &'db dyn crate::Db, + package_id: PackageId<'db>, + callee_expr_id: ExprId, + body: &ExprBody, + expressions: &FxHashMap, + options: CallResolutionOptions, + out: &mut BTreeSet, +) { + let type_level_facts = expressions.get(&callee_expr_id).and_then(|ty| match ty { + Ty::Function { throws, .. } => Some(flatten_ty_to_facts(throws)), + _ => None, + }); + + if let Some(facts) = type_level_facts { + let filtered: BTreeSet = facts + .iter() + .filter(|fact| options.include_typevars || !matches!(fact, Ty::TypeVar(_, _))) + .cloned() + .collect(); + if !filtered.is_empty() { + out.extend(filtered); + return; + } + if facts.iter().any(|fact| matches!(fact, Ty::TypeVar(_, _))) && !options.include_typevars { + return; + } + } + + if let Some(target) = call_target_name(callee_expr_id, body, expressions) { + let throws = function_throw_sets(db, package_id); + if let Some(transitive) = throws.transitive_for(&target) { + out.extend(transitive.iter().cloned()); + return; + } + } + + if options.unknown_on_unresolved_call + && !matches!(expressions.get(&callee_expr_id), Some(Ty::Function { .. })) + { + out.insert(Ty::Unknown { + attr: TyAttr::default(), + }); + } +} + +fn collect_throw_facts_from_value( + value_expr_id: ExprId, + body: &ExprBody, + expressions: &FxHashMap, + out: &mut BTreeSet, +) { + if let Some(thrown_ty) = expressions.get(&value_expr_id) { + out.extend(flatten_ty_to_facts(thrown_ty)); + return; + } + + match &body.exprs[value_expr_id] { + Expr::Literal(lit) => { + out.extend(flatten_ty_to_facts(&Ty::Literal( + lit.clone(), + Freshness::Regular, + TyAttr::default(), + ))); + } + Expr::ByteStringLiteral(_) => { + out.insert(Ty::Primitive( + crate::ty::PrimitiveType::Uint8Array, + TyAttr::default(), + )); + } + Expr::Null => { + out.insert(Ty::Primitive( + crate::ty::PrimitiveType::Null, + TyAttr::default(), + )); + } + _ => { + out.insert(Ty::Unknown { + attr: TyAttr::default(), + }); + } + } +} + +fn call_target_name( + callee_expr_id: ExprId, + body: &ExprBody, + expressions: &FxHashMap, +) -> Option { + match &body.exprs[callee_expr_id] { + Expr::Path(segments) if !segments.is_empty() => Some(path_name(segments)), + Expr::FieldAccess { base, field } => { + if let Some(Ty::Class(qn, _)) = expressions.get(base) { + Some(class_method_key(qn, field)) + } else { + let mut segments = expr_to_path_segments(*base, body)?; + segments.push(field.clone()); + Some(path_name(&segments)) + } + } + _ => None, + } +} + +fn expr_to_path_segments(expr_id: ExprId, body: &ExprBody) -> Option> { + match &body.exprs[expr_id] { + Expr::Path(segments) if !segments.is_empty() => Some(segments.clone()), + Expr::FieldAccess { base, field } => { + let mut segments = expr_to_path_segments(*base, body)?; + segments.push(field.clone()); + Some(segments) + } + _ => None, + } +} + +fn class_method_key(class_name: &QualifiedTypeName, method: &Name) -> Name { + let ns = class_name.namespace(); + let key = if ns.is_empty() { + format!("{}.{}", class_name.name(), method) + } else { + let ns_str = ns.iter().map(Name::as_str).collect::>().join("."); + format!("{ns_str}.{}.{}", class_name.name(), method) + }; + Name::new(key) +} + +fn path_name(segments: &[Name]) -> Name { + if segments.len() == 1 { + segments[0].clone() + } else { + Name::new( + segments + .iter() + .map(Name::as_str) + .collect::>() + .join("."), + ) + } +} diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 93566486e7..c09821dbc0 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -85,6 +85,9 @@ pub struct ScopeInference<'db> { resolutions: FxHashMap>, /// Match expressions that the exhaustiveness checker determined cover all cases. exhaustive_matches: FxHashSet, + /// Residual throws for each catch expression: the types that propagate + /// out of the catch (types thrown but not handled by any arm). + catch_residual_throws: FxHashMap>, /// Diagnostics and other rare data. Heap-allocated only when non-empty. extra: Option>>, } @@ -157,6 +160,56 @@ impl<'db> ScopeInference<'db> { self.exhaustive_matches.iter() } + /// Compute the concrete escaping throws for a body using post-inference + /// expression types plus the package throw graph fallback for named calls. + /// + /// This intentionally excludes `Ty::TypeVar` facts so callers can use it as + /// the function body's own concrete throws and union it with synthetic + /// callback effect vars separately. + pub fn effective_throws( + &self, + db: &'db dyn crate::Db, + package_id: PackageId<'db>, + body: &baml_compiler2_ast::ExprBody, + ) -> crate::ty::Ty { + use std::collections::BTreeSet; + + use crate::ty::{PrimitiveType, TyAttr}; + + let mut facts: BTreeSet = crate::effective_throws::collect_effective_throws( + db, + package_id, + body, + &self.expressions, + &self.catch_residual_throws, + false, + false, + ); + + // Remove Never and Void facts (they don't represent thrown exceptions). + facts.retain(|f| !matches!(f, Ty::Never { .. } | Ty::Void { .. })); + + // Widen string literals to string primitive (matches throw_inference behavior). + let widened: BTreeSet = facts + .into_iter() + .map(|f| match &f { + Ty::Literal(baml_compiler2_ast::Literal::String(_), _, _) => { + Ty::Primitive(PrimitiveType::String, TyAttr::default()) + } + other => other.clone(), + }) + .collect(); + + let mut members: Vec = widened.into_iter().collect(); + match members.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => members.remove(0), + _ => Ty::Union(members, TyAttr::default()), + } + } + /// Get diagnostics for this scope (empty slice if none). pub fn diagnostics(&self) -> &TypeCheckDiagnostics<'db> { self.extra @@ -582,7 +635,14 @@ pub fn infer_scope_types<'db>( } } - let (expressions, bindings, resolutions, exhaustive_matches, diagnostics) = builder.finish(); + let ( + expressions, + bindings, + resolutions, + exhaustive_matches, + catch_residual_throws, + diagnostics, + ) = builder.finish(); let extra = if diagnostics.is_empty() { None @@ -595,6 +655,7 @@ pub fn infer_scope_types<'db>( bindings, resolutions, exhaustive_matches, + catch_residual_throws, extra, } } diff --git a/baml_language/crates/baml_compiler2_tir/src/lib.rs b/baml_language/crates/baml_compiler2_tir/src/lib.rs index deac4bb3cc..abd7fbf431 100644 --- a/baml_language/crates/baml_compiler2_tir/src/lib.rs +++ b/baml_language/crates/baml_compiler2_tir/src/lib.rs @@ -19,6 +19,7 @@ pub mod analysis; pub mod builder; pub mod cycle_detector; +pub mod effective_throws; pub mod generics; pub mod infer_context; pub mod inference; diff --git a/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs b/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs index cd785a78d5..a46ae8a7ff 100644 --- a/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs +++ b/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs @@ -1,5 +1,6 @@ //! `TypeExpr → Ty` lowering using package-level name resolution. +use baml_base::Name; use baml_compiler2_ast::TypeExpr; use baml_compiler2_hir::{ contributions::Definition, @@ -11,6 +12,123 @@ use crate::{ ty::{Freshness, PrimitiveType, QualifiedTypeName, Ty, TyAttr}, }; +/// Context for lowering function types — determines how an omitted `throws` +/// clause is interpreted. +/// +/// - `DefaultClosed`: omitted throws ⇒ `Ty::Never` (pure by default). +/// Used for class fields, type aliases, return types, locals, and nested +/// function types inside parameter/return positions. +/// - `DirectParamRoot { param_name }`: omitted throws ⇒ fresh effect `TypeVar` +/// named `__throws_`. Used only for the *outermost* function +/// type of a direct callback parameter. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FnTypeLoweringContext { + /// Omitted throws ⇒ `Ty::Never`. + DefaultClosed, + /// Omitted throws ⇒ fresh effect `TypeVar` named `__throws_`. + DirectParamRoot { param_name: Name }, +} + +/// Lower a function-typed `TypeExpr` with implicit effect polymorphism. +/// +/// When `ctx` is `DirectParamRoot { param_name }` and the function type has no +/// explicit `throws` clause, a fresh `TypeVar` named `__throws_` is +/// generated and recorded in `synthetic_effect_vars`. All nested positions +/// (params, return type, nested function types) are lowered with +/// `DefaultClosed`. +/// +/// Returns the lowered `Ty`. Any generated effect var names are pushed into +/// `synthetic_effect_vars` so the caller can include them in the generic +/// binding set. +#[allow(clippy::too_many_arguments)] +pub fn lower_type_expr_with_fn_context( + db: &dyn crate::Db, + type_expr: &TypeExpr, + package_items: &PackageItems<'_>, + ns_context: &[Name], + generic_params: &[Name], + diagnostics: &mut Vec, + ctx: &FnTypeLoweringContext, + synthetic_effect_vars: &mut Vec, +) -> Ty { + match type_expr { + TypeExpr::Function { + params, + ret, + throws, + .. + } => { + // Nested function types inside params/return are always DefaultClosed. + let param_tys: Vec<(Option, Ty)> = params + .iter() + .map(|p| { + ( + p.name.clone(), + lower_type_expr_in_ns( + db, + &p.ty, + package_items, + ns_context, + generic_params, + diagnostics, + ), + ) + }) + .collect(); + + let ret_ty = lower_type_expr_in_ns( + db, + ret, + package_items, + ns_context, + generic_params, + diagnostics, + ); + + let throws_ty = match throws { + Some(t) => { + // Explicit throws clause — lower normally. + lower_type_expr_in_ns( + db, + t, + package_items, + ns_context, + generic_params, + diagnostics, + ) + } + None => match ctx { + FnTypeLoweringContext::DirectParamRoot { param_name } => { + // No explicit throws + direct param ⇒ fresh effect TypeVar. + let var_name = Name::new(format!("__throws_{param_name}")); + synthetic_effect_vars.push(var_name.clone()); + Ty::TypeVar(var_name, TyAttr::default()) + } + FnTypeLoweringContext::DefaultClosed => Ty::Never { + attr: TyAttr::default(), + }, + }, + }; + + Ty::Function { + params: param_tys, + ret: Box::new(ret_ty), + throws: Box::new(throws_ty), + attr: TyAttr::default(), + } + } + // Non-function types fall through to normal lowering. + _ => lower_type_expr_in_ns( + db, + type_expr, + package_items, + ns_context, + generic_params, + diagnostics, + ), + } +} + /// Resolve an AST `TypeExpr` to a `Ty` using package-level name resolution. /// /// Names are resolved against `package_items`: classes, enums, and type aliases diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/array_map_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/array_map_throws.baml new file mode 100644 index 0000000000..f364825be1 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/array_map_throws.baml @@ -0,0 +1,16 @@ +// === Array.map with throwing callback === + +// Pure map callback +function test_map_pure() -> int[] { + let items: int[] = [1, 2, 3] + items.map((x) -> { x * 2 }) +} + +// Throwing map callback +function test_map_throwing() -> int[] { + let items: int[] = [1, 2, 3] + items.map((x) -> { + if (x == 2) { throw "found two" } + x * 2 + }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/catch_absorbs_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/catch_absorbs_throws.baml new file mode 100644 index 0000000000..bd21c4a5a0 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/catch_absorbs_throws.baml @@ -0,0 +1,32 @@ +// === catch should absorb throws === + +function may_fail(x: int) -> string throws string { + if (x == 0) { throw "zero" } + "ok" +} + +// Calling a throwing function without catch - should propagate throws +function test_no_catch() -> string { + may_fail(1) +} + +// Catching with wildcard - should absorb throws +function test_catch_all() -> string { + may_fail(1) catch (e) { + _ => "caught" + } +} + +// Catching specific type - should absorb that type +function test_catch_string() -> string { + may_fail(1) catch (e) { + _: string => "caught string" + } +} + +// Catch then re-throw - should propagate the re-thrown type +function test_catch_rethrow() -> string { + may_fail(1) catch (e) { + _ => throw 42 + } +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/chained_hof.baml b/baml_language/crates/baml_tests/projects/function_type_throws/chained_hof.baml new file mode 100644 index 0000000000..d1b501af7d --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/chained_hof.baml @@ -0,0 +1,17 @@ +// === Chained HOF: throws through multiple layers === + +function apply_inner(f: () -> int) -> int { f() } + +function apply_outer(g: () -> int) -> int { + apply_inner(g) +} + +// Pure through two layers +function test_chained_pure() -> int { + apply_outer(() -> int { 42 }) +} + +// Throwing through two layers +function test_chained_throwing() -> int { + apply_outer(() -> int { throw "deep" }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/explicit_param_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/explicit_param_throws.baml new file mode 100644 index 0000000000..c43c20207c --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/explicit_param_throws.baml @@ -0,0 +1,23 @@ +// === Explicit throws on callback param should NOT create effect var === + +// When the callback param has explicit throws, no __throws_ var should appear +function apply_explicit(f: () -> int throws string) -> int { + f() +} + +function test_explicit_param_pure() -> int { + apply_explicit(() -> int { 42 }) +} + +function test_explicit_param_throwing() -> int { + apply_explicit(() -> int { throw "ok" }) +} + +// Explicit throws never on param - pure callbacks only +function apply_pure_only(f: () -> int throws never) -> int { + f() +} + +function test_pure_only_pure() -> int { + apply_pure_only(() -> int { 42 }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/hof_own_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/hof_own_throws.baml new file mode 100644 index 0000000000..6bbaeba407 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/hof_own_throws.baml @@ -0,0 +1,19 @@ +// === HOF with its own throws + callback throws === +// A function that both throws itself AND rethrows callback effects. +// The effective throws should be the UNION of both. + +function apply_guarded(f: () -> int) -> int { + let result = f() + if (result < 0) { throw "negative result" } + result +} + +// Pure callback: function's own throw should still appear +function test_guarded_pure() -> int { + apply_guarded(() -> int { 42 }) +} + +// Throwing callback: both the function's own throw and the callback's throw +function test_guarded_throwing() -> int { + apply_guarded(() -> int { throw 99 }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/hof_rethrows.baml b/baml_language/crates/baml_tests/projects/function_type_throws/hof_rethrows.baml new file mode 100644 index 0000000000..5da9797b74 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/hof_rethrows.baml @@ -0,0 +1,43 @@ +// === HOF rethrows: effect propagation through higher-order functions === + +// Basic rethrows - pure callback +function run_pure(f: () -> int) -> int { f() } +function test_run_pure() -> int { + run_pure(() -> int { 42 }) +} + +// Basic rethrows - throwing callback +function run_throwing(f: () -> int) -> int { f() } +function test_run_throwing() -> int { + run_throwing(() -> int { throw "boom" }) +} + +// Rethrows with multiple callback params +function run_two(f: () -> int, g: () -> int) -> int { + f() + g() +} + +function test_two_pure() -> int { + run_two(() -> int { 1 }, () -> int { 2 }) +} + +function test_two_one_throws() -> int { + run_two(() -> int { 1 }, () -> int { throw "boom" }) +} + +function test_two_both_throw() -> int { + run_two(() -> int { throw "a" }, () -> int { throw 42 }) +} + +// Rethrows with parametric callback +function map_it(x: int, f: (int) -> string) -> string { + f(x) +} + +function test_map_pure() -> string { + map_it(1, (n: int) -> string { "ok" }) +} + +function test_map_throwing() -> string { + map_it(1, (n: int) -> string { throw "bad" }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml index 94b0004d1f..8d1c32fc4a 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml @@ -34,3 +34,22 @@ function apply_and_throw(f: () -> int) -> int { function test_apply_and_throw_pure() -> int { apply_and_throw(() -> int { 42 }) } + +// Function with omitted throws that still throws from its body. +function helper_with_body_throw() -> int { + throw "helper boom" +} + +// Wrapper should keep the helper's concrete body throw alongside callback effects. +function apply_with_helper(f: () -> int) -> int { + helper_with_body_throw() + f() +} + +function test_apply_with_helper_pure() -> int { + apply_with_helper(() -> int { 42 }) +} + +function test_apply_with_helper_throwing() -> int { + apply_with_helper(() -> int { throw 42 }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/mixed_params.baml b/baml_language/crates/baml_tests/projects/function_type_throws/mixed_params.baml new file mode 100644 index 0000000000..33bce92286 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/mixed_params.baml @@ -0,0 +1,26 @@ +// === Mixed: function with both function and non-function params === + +// Only the function param should get an effect var +function apply_with_arg(x: int, f: (int) -> int) -> int { + f(x) +} + +function test_mixed_pure() -> int { + apply_with_arg(5, (n: int) -> int { n * 2 }) +} + +function test_mixed_throwing() -> int { + apply_with_arg(5, (n: int) -> int { + if (n < 0) { throw "negative" } + n * 2 + }) +} + +// Multiple non-function params with one callback +function apply_with_many_args(a: int, b: string, f: (int, string) -> string) -> string { + f(a, b) +} + +function test_many_args_pure() -> string { + apply_with_many_args(1, "hello", (n: int, s: string) -> string { s }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/returned_closures.baml b/baml_language/crates/baml_tests/projects/function_type_throws/returned_closures.baml new file mode 100644 index 0000000000..1be16e1bb5 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/returned_closures.baml @@ -0,0 +1,22 @@ +// === Returned closures: should default to throws never === + +// Returning a pure closure - should be fine +function make_pure() -> (() -> int) { + () -> int { 42 } +} + +// Returning a closure with explicit throws +function make_thrower() -> (() -> int throws string) { + () -> int { throw "error" } +} + +// Using a returned closure +function test_use_pure() -> int { + let f = make_pure() + f() +} + +function test_use_thrower() -> int { + let f = make_thrower() + f() +} diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap index af8005ccd3..e534fa99cc 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap @@ -10,12 +10,12 @@ class baml.Array { } class baml.Map { } -function baml.Map.map_keys(self: baml.Map, f: (unknown) -> U) -> U[] throws never { +function baml.Map.map_keys(self: baml.Map, f: (unknown) -> U throws __throws_f) -> U[] throws __throws_f { { : U[] self.keys().map(f) : U[] } } -function baml.Map.map_values(self: baml.Map, f: (unknown) -> U) -> U[] throws never { +function baml.Map.map_values(self: baml.Map, f: (unknown) -> U throws __throws_f) -> U[] throws __throws_f { { : U[] self.values().map(f) : U[] } @@ -28,7 +28,7 @@ class baml.Map$stream { --- /baml/core.baml --- --- /baml/ns_env/env.baml --- -function baml.env.get_or_panic(key: string) -> string throws never { +function baml.env.get_or_panic(key: string) -> string throws baml.errors.Io { { : "" | string let val = get(key) : string? if (val == null : bool) : "" | string @@ -183,7 +183,7 @@ function baml.llm.PlannerState.rr_pick_index(self: baml.llm.PlannerState, client counter + local_offset % sub_client_count : int } } -function baml.llm.render_prompt(client: baml.llm.Client, function_name: string, args: map) -> baml.llm.PromptAst throws never { +function baml.llm.render_prompt(client: baml.llm.Client, function_name: string, args: map) -> baml.llm.PromptAst throws baml.errors.InvalidArgument | baml.errors.LlmClient | baml.errors.RenderPrompt | unknown { { : baml.llm.PromptAst let jinja_string = get_jinja_template(function_name) : string let return_type = get_return_type(function_name) : type @@ -192,20 +192,20 @@ function baml.llm.render_prompt(client: baml.llm.Client, function_name: string, primitive_client.specialize_prompt(prompt) : baml.llm.PromptAst } } -function baml.llm.build_request(client: baml.llm.Client, function_name: string, args: map) -> baml.http.Request throws never { +function baml.llm.build_request(client: baml.llm.Client, function_name: string, args: map) -> baml.http.Request throws baml.errors.LlmClient | unknown { { : baml.http.Request let primitive_client = client.to_primitive_client() : baml.llm.PrimitiveClient let specialized_prompt = render_prompt(client, function_name, args) : baml.llm.PromptAst primitive_client.build_request(specialized_prompt) : baml.http.Request } } -function baml.llm.parse(function_name: string, json: string) -> T throws never { +function baml.llm.parse(function_name: string, json: string) -> T throws baml.errors.InvalidArgument | baml.errors.LlmClient { { : T let return_type = get_return_type(function_name) : type __sap_parse(json, return_type) : T } } -function baml.llm.call_llm_function(client: baml.llm.Client, function_name: string, args: map) -> T throws never { +function baml.llm.call_llm_function(client: baml.llm.Client, function_name: string, args: map) -> T throws baml.errors.InvalidArgument | unknown { { : T let jinja_string = get_jinja_template(function_name) : string let context = ExecutionContext { jinja_string: jinja_string, args: args, function_name: function_name } : baml.llm.ExecutionContext @@ -248,7 +248,7 @@ class baml.llm.Client { retry: baml.llm.RetryPolicy? counter: int } -function baml.llm.Client.to_primitive_client(self: baml.llm.Client) -> baml.llm.PrimitiveClient throws never { +function baml.llm.Client.to_primitive_client(self: baml.llm.Client) -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument | unknown { { : baml.llm.PrimitiveClient if (self.name.includes("/") : bool) : baml.llm.PrimitiveClient { : never @@ -266,7 +266,7 @@ function baml.llm.Client.build_attempt(self: baml.llm.Client) -> baml.llm.Orches self.build_attempt_with_state(planner_state) : baml.llm.OrchestrationStep[] } } -function baml.llm.Client.build_attempt_with_state(self: baml.llm.Client, planner_state: baml.llm.PlannerState) -> baml.llm.OrchestrationStep[] throws never { +function baml.llm.Client.build_attempt_with_state(self: baml.llm.Client, planner_state: baml.llm.PlannerState) -> baml.llm.OrchestrationStep[] throws baml.errors.InvalidArgument { { : baml.llm.OrchestrationStep[] | never[] match (self.client_type : baml.llm.ClientType) : baml.llm.OrchestrationStep[] | never[] ClientType.Primitive => @@ -345,7 +345,7 @@ function baml.llm.Client.build_plan_with_state(self: baml.llm.Client, planner_st self.build_attempt_with_state(planner_state) : baml.llm.OrchestrationStep[] } } -function baml.llm.Client.execute(self: baml.llm.Client, context: baml.llm.ExecutionContext, inherited_delay_ms: int) -> T throws never { +function baml.llm.Client.execute(self: baml.llm.Client, context: baml.llm.ExecutionContext, inherited_delay_ms: int) -> T throws unknown { { : never match (self.retry : baml.llm.RetryPolicy?) : never r: RetryPolicy => @@ -392,7 +392,7 @@ function baml.llm.Client.execute(self: baml.llm.Client, context: baml.llm.Exe } } } -function baml.llm.Client.execute_once(self: baml.llm.Client, context: baml.llm.ExecutionContext, active_delay_ms: int) -> T throws never { +function baml.llm.Client.execute_once(self: baml.llm.Client, context: baml.llm.ExecutionContext, active_delay_ms: int) -> T throws baml.errors.InvalidArgument | baml.errors.Io | baml.errors.LlmClient | baml.errors.RenderPrompt | baml.errors.Timeout | unknown { { : T match (self.client_type : baml.llm.ClientType) : T ClientType.Primitive => diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap index 1153c68c2a..06a79c288e 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap @@ -26,7 +26,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector thro TestCollector { prefix: prefix, tests: [], testsets: [] } : testing.TestCollector } } -function testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null, runner: testing.TestRunner?) -> null throws never { +function testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null throws __throws_body, runner: testing.TestRunner?) -> null throws __throws_body { { : null let full_name = : string if (self.prefix == "" : bool) : string @@ -63,7 +63,7 @@ function testing.TestCollector.register_test(self: testing.TestCollector, name: null : null } } -function testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null, runner: testing.TestSetRunner?) -> null throws never { +function testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null throws __throws_collector, runner: testing.TestSetRunner?) -> null throws __throws_collector { { : null let full_name = : string if (self.prefix == "" : bool) : string @@ -100,7 +100,7 @@ function testing.TestCollector.register_test_set(self: testing.TestCollector, na null : null } } -function testing.TestCollector.find_testset(self: testing.TestCollector, name: string) -> testing.TestSetRegistration throws never { +function testing.TestCollector.find_testset(self: testing.TestCollector, name: string) -> testing.TestSetRegistration throws string { { : never for ts in self.testsets { : void @@ -112,7 +112,7 @@ function testing.TestCollector.find_testset(self: testing.TestCollector, name: s throw "TestSet not found: " + name : string } } -function testing.$invoke_collector(collector: (testing.TestCollector) -> null, c: testing.TestCollector) -> null throws never { +function testing.$invoke_collector(collector: (testing.TestCollector) -> null throws __throws_collector, c: testing.TestCollector) -> null throws __throws_collector { { : null collector(c) : null } @@ -160,7 +160,7 @@ function testing.TestRegistry.serialize(self: testing.TestRegistry) -> testing.S return items : testing.SerializedTestDef[] } } -function testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> testing.TestReport throws never { +function testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> testing.TestReport throws string | unknown { { : never for t in self.collector.tests { : void @@ -180,7 +180,7 @@ function testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) throw "Test not found: " + name : string } } -function testing.TestRegistry.expand_set(self: testing.TestRegistry, name: string) -> testing.SerializedTestDef[] throws never { +function testing.TestRegistry.expand_set(self: testing.TestRegistry, name: string) -> testing.SerializedTestDef[] throws string | unknown { { : never for ts in self.collector.testsets { : void @@ -206,7 +206,7 @@ function testing.TestRegistry.expand_set(self: testing.TestRegistry, name: strin throw "TestSet not found..." + name : string } } -function testing.run_test(body: () -> null, runner: testing.TestRunner?) -> testing.TestReport throws never { +function testing.run_test(body: () -> null throws __throws_body, runner: testing.TestRunner?) -> testing.TestReport throws __throws_body { { : testing.TestReport let base_run = : () -> testing.TestReport () -> TestReport { ... } : () -> testing.TestReport @@ -227,7 +227,7 @@ function testing.run_test(body: () -> null, runner: testing.TestRunner?) -> test } lambda testing.run_test { } -function testing.run_testset(run_children: () -> testing.TestSetReport, runner: testing.TestSetRunner?) -> testing.TestSetReport throws never { +function testing.run_testset(run_children: () -> testing.TestSetReport throws __throws_run_children, runner: testing.TestSetRunner?) -> testing.TestSetReport throws __throws_run_children { { : testing.TestSetReport let effective_run = : () -> testing.TestSetReport match (runner : testing.TestSetRunner?) : () -> testing.TestSetReport diff --git a/baml_language/crates/baml_tests/snapshots/builtin_io/baml_tests__builtin_io__04_tir.snap b/baml_language/crates/baml_tests/snapshots/builtin_io/baml_tests__builtin_io__04_tir.snap index 01a1ea92e4..18f9eb73e8 100644 --- a/baml_language/crates/baml_tests/snapshots/builtin_io/baml_tests__builtin_io__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/builtin_io/baml_tests__builtin_io__04_tir.snap @@ -2,18 +2,18 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === -function user.ReadFile() -> string throws never { +function user.ReadFile() -> string throws baml.errors.Io { { : string let f = baml.fs.open("example.txt") : baml.fs.File f.read() : string } } -function user.ReadFileInline() -> string throws never { +function user.ReadFileInline() -> string throws baml.errors.Io { { : string baml.fs.open("data.json").read() : string } } -function user.ReadMultipleFiles() -> string throws never { +function user.ReadMultipleFiles() -> string throws baml.errors.Io { { : string let f1 = baml.fs.open("file1.txt") : baml.fs.File let f2 = baml.fs.open("file2.txt") : baml.fs.File @@ -22,7 +22,7 @@ function user.ReadMultipleFiles() -> string throws never { content1 + content2 : string } } -function user.ReadAndClose() -> string throws never { +function user.ReadAndClose() -> string throws baml.errors.Io { { : string let f = baml.fs.open("example.txt") : baml.fs.File let content = f.read() : string @@ -30,18 +30,18 @@ function user.ReadAndClose() -> string throws never { content : string } } -function user.ConnectAndRead() -> string throws never { +function user.ConnectAndRead() -> string throws baml.errors.Io | baml.errors.Timeout { { : string let sock = baml.net.connect("127.0.0.1:8080") : baml.net.Socket sock.read() : string } } -function user.ConnectAndReadInline() -> string throws never { +function user.ConnectAndReadInline() -> string throws baml.errors.Io | baml.errors.Timeout { { : string baml.net.connect("localhost:3000").read() : string } } -function user.MultipleConnections() -> string throws never { +function user.MultipleConnections() -> string throws baml.errors.Io | baml.errors.Timeout { { : string let s1 = baml.net.connect("server1:80") : baml.net.Socket let s2 = baml.net.connect("server2:80") : baml.net.Socket @@ -50,7 +50,7 @@ function user.MultipleConnections() -> string throws never { data1 + data2 : string } } -function user.ConnectReadAndClose() -> string throws never { +function user.ConnectReadAndClose() -> string throws baml.errors.Io | baml.errors.Timeout { { : string let sock = baml.net.connect("127.0.0.1:8080") : baml.net.Socket let data = sock.read() : string @@ -58,18 +58,18 @@ function user.ConnectReadAndClose() -> string throws never { data : string } } -function user.RunCommand() -> string throws never { +function user.RunCommand() -> string throws baml.errors.Io { { : string baml.sys.shell("echo hello") : string } } -function user.RunCommandWithVar() -> string throws never { +function user.RunCommandWithVar() -> string throws baml.errors.Io { { : string let cmd = "ls -la" : "ls -la" -> string baml.sys.shell(cmd) : string } } -function user.ChainCommands() -> string throws never { +function user.ChainCommands() -> string throws baml.errors.Io { { : string let result1 = baml.sys.shell("pwd") : string let result2 = baml.sys.shell("whoami") : string diff --git a/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__04_tir.snap b/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__04_tir.snap index 9ae1f9d681..a7f61b47e0 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__04_tir.snap @@ -2,7 +2,7 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === -function user.MayFailIntOrString(x: int) -> string throws never { +function user.MayFailIntOrString(x: int) -> string throws int | string { { : "ok" match (x : int) : "ok" 0 => @@ -13,7 +13,7 @@ function user.MayFailIntOrString(x: int) -> string throws never { "ok" : "ok" } } -function user.PartialCatch(x: int) -> string throws never { +function user.PartialCatch(x: int) -> string throws int { { : string | "caught string" catch (MayFailIntOrString(x) : string) : unknown catch (e) diff --git a/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_tir.snap b/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_tir.snap index 0415eec479..d81cac5941 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_tir.snap @@ -2,7 +2,7 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === -function user.MayFail(x: int) -> string throws never { +function user.MayFail(x: int) -> string throws string { { : "ok" match (x : int) : "ok" 0 => @@ -11,7 +11,7 @@ function user.MayFail(x: int) -> string throws never { "ok" : "ok" } } -function user.MayFailIntOrString(x: int) -> string throws never { +function user.MayFailIntOrString(x: int) -> string throws int | string { { : "ok" match (x : int) : "ok" 0 => @@ -22,7 +22,7 @@ function user.MayFailIntOrString(x: int) -> string throws never { "ok" : "ok" } } -function user.AlsoMayFail(x: int) -> string throws never { +function user.AlsoMayFail(x: int) -> string throws string { { : "also ok" match (x : int) : "also ok" 0 => @@ -89,7 +89,7 @@ function user.NestedCatch(x: int) -> string throws never { inner : string } } -function user.CatchAndRethrow(x: int) -> string throws never { +function user.CatchAndRethrow(x: int) -> string throws string { { : string catch (MayFail(x) : string) : unknown catch (e) @@ -97,7 +97,7 @@ function user.CatchAndRethrow(x: int) -> string throws never { throw e : never } } -function user.ThrowStatement(x: int) -> string throws never { +function user.ThrowStatement(x: int) -> string throws string { { : never throw "error" : "error" "unreachable" : unknown diff --git a/baml_language/crates/baml_tests/snapshots/catch_throw_regressions/baml_tests__catch_throw_regressions__04_tir.snap b/baml_language/crates/baml_tests/snapshots/catch_throw_regressions/baml_tests__catch_throw_regressions__04_tir.snap index fe4dfed276..58c4607f6a 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_throw_regressions/baml_tests__catch_throw_regressions__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_throw_regressions/baml_tests__catch_throw_regressions__04_tir.snap @@ -3,7 +3,7 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === enum user.Status -function user.AlwaysThrowsStatus(n: int) -> user.Status throws never { +function user.AlwaysThrowsStatus(n: int) -> user.Status throws int | string { { : never match (n : int) : never 0 => @@ -12,7 +12,7 @@ function user.AlwaysThrowsStatus(n: int) -> user.Status throws never { throw 404 : never } } -function user.WrapperAlwaysThrows(n: int) -> user.Status throws never { +function user.WrapperAlwaysThrows(n: int) -> user.Status throws int | string { { : user.Status AlwaysThrowsStatus(n) : user.Status } @@ -29,7 +29,7 @@ function user.CatchAlwaysThrows(n: int) -> string throws never { } !! 622..644: type mismatch: expected string, got user.Status } -function user.ThrowsAllStatusVariants(n: int) -> user.Status throws never { +function user.ThrowsAllStatusVariants(n: int) -> user.Status throws user.Status.AuthError | user.Status.HttpError | user.Status.IndexError | user.Status.SomeOtherError { { : never match (n : int) : never 1 => @@ -60,7 +60,7 @@ function user.CatchAllStatusVariants(n: int) -> string throws never { } !! 1115..1141: type mismatch: expected string, got user.Status } -function user.RecA(n: int) -> string throws never { +function user.RecA(n: int) -> string throws string { { : string match (n : int) : string 0 => @@ -69,7 +69,7 @@ function user.RecA(n: int) -> string throws never { RecB(n) : string } } -function user.RecB(n: int) -> string throws never { +function user.RecB(n: int) -> string throws string { { : string RecA(n) : string } @@ -82,7 +82,7 @@ function user.CatchRecursiveThrow(n: int) -> string throws never { "handled" : "handled" } } -function user.UnreachableTailStillChecked() -> string throws never { +function user.UnreachableTailStillChecked() -> string throws string { { : never { : never throw "error" : "error" diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap index 4034704ead..9dba9be02e 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap @@ -670,7 +670,7 @@ function user.DeepFnReturn() -> (x: int) -> (y: int) -> (z: int) -> (w: int) -> } !! 2354..2355: type mismatch: expected (x: int) -> (y: int) -> (z: int) -> (w: int) -> map, got 1 } -function user.DeepParamTypes(nested: map>>, callback: (x: int) -> (y: int) -> (z: int) -> string) -> string throws never { +function user.DeepParamTypes(nested: map>>, callback: (x: int) -> (y: int) -> (z: int) -> string throws __throws_callback) -> string throws __throws_callback { { : "done" "done" : "done" } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__array_map_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__array_map_throws.snap new file mode 100644 index 0000000000..c8e8fe73d9 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__array_map_throws.snap @@ -0,0 +1,114 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Array" +Dot "." +Word "map" +Word "with" +Word "throwing" +Word "callback" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Pure" +Word "map" +Word "callback" +Function "function" +Word "test_map_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBracket "[" +RBracket "]" +LBrace "{" +Let "let" +Word "items" +Colon ":" +Word "int" +LBracket "[" +RBracket "]" +Equals "=" +LBracket "[" +IntegerLiteral "1" +Comma "," +IntegerLiteral "2" +Comma "," +IntegerLiteral "3" +RBracket "]" +Word "items" +Dot "." +Word "map" +LParen "(" +LParen "(" +Word "x" +RParen ")" +Arrow "->" +LBrace "{" +Word "x" +Star "*" +IntegerLiteral "2" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Throwing" +Word "map" +Word "callback" +Function "function" +Word "test_map_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBracket "[" +RBracket "]" +LBrace "{" +Let "let" +Word "items" +Colon ":" +Word "int" +LBracket "[" +RBracket "]" +Equals "=" +LBracket "[" +IntegerLiteral "1" +Comma "," +IntegerLiteral "2" +Comma "," +IntegerLiteral "3" +RBracket "]" +Word "items" +Dot "." +Word "map" +LParen "(" +LParen "(" +Word "x" +RParen ")" +Arrow "->" +LBrace "{" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "2" +RParen ")" +LBrace "{" +Throw "throw" +Quote "\"" +Word "found" +Word "two" +Quote "\"" +RBrace "}" +Word "x" +Star "*" +IntegerLiteral "2" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__catch_absorbs_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__catch_absorbs_throws.snap new file mode 100644 index 0000000000..6d68ac7c0f --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__catch_absorbs_throws.snap @@ -0,0 +1,166 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Catch "catch" +Word "should" +Word "absorb" +Throws "throws" +EqualsEquals "==" +Equals "=" +Function "function" +Word "may_fail" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "string" +LBrace "{" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Quote "\"" +Word "zero" +Quote "\"" +RBrace "}" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +Slash "/" +Slash "/" +Word "Calling" +Word "a" +Word "throwing" +Function "function" +Word "without" +Catch "catch" +Minus "-" +Word "should" +Word "propagate" +Throws "throws" +Function "function" +Word "test_no_catch" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "may_fail" +LParen "(" +IntegerLiteral "1" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Catching" +Word "with" +Word "wildcard" +Minus "-" +Word "should" +Word "absorb" +Throws "throws" +Function "function" +Word "test_catch_all" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "may_fail" +LParen "(" +IntegerLiteral "1" +RParen ")" +Catch "catch" +LParen "(" +Word "e" +RParen ")" +LBrace "{" +Word "_" +FatArrow "=>" +Quote "\"" +Word "caught" +Quote "\"" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "Catching" +Word "specific" +Word "type" +Minus "-" +Word "should" +Word "absorb" +Word "that" +Word "type" +Function "function" +Word "test_catch_string" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "may_fail" +LParen "(" +IntegerLiteral "1" +RParen ")" +Catch "catch" +LParen "(" +Word "e" +RParen ")" +LBrace "{" +Word "_" +Colon ":" +Word "string" +FatArrow "=>" +Quote "\"" +Word "caught" +Word "string" +Quote "\"" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "Catch" +Word "then" +Word "re-throw" +Minus "-" +Word "should" +Word "propagate" +Word "the" +Word "re-thrown" +Word "type" +Function "function" +Word "test_catch_rethrow" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "may_fail" +LParen "(" +IntegerLiteral "1" +RParen ")" +Catch "catch" +LParen "(" +Word "e" +RParen ")" +LBrace "{" +Word "_" +FatArrow "=>" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__chained_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__chained_hof.snap new file mode 100644 index 0000000000..20f293ced9 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__chained_hof.snap @@ -0,0 +1,102 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Chained" +Word "HOF" +Colon ":" +Throws "throws" +Word "through" +Word "multiple" +Word "layers" +EqualsEquals "==" +Equals "=" +Function "function" +Word "apply_inner" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "apply_outer" +LParen "(" +Word "g" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_inner" +LParen "(" +Word "g" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Pure" +Word "through" +Word "two" +Word "layers" +Function "function" +Word "test_chained_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_outer" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Throwing" +Word "through" +Word "two" +Word "layers" +Function "function" +Word "test_chained_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_outer" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "deep" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__explicit_param_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__explicit_param_throws.snap new file mode 100644 index 0000000000..cb05bcd1f1 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__explicit_param_throws.snap @@ -0,0 +1,140 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Explicit" +Throws "throws" +Word "on" +Word "callback" +Word "param" +Word "should" +Word "NOT" +Word "create" +Word "effect" +Word "var" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "When" +Word "the" +Word "callback" +Word "param" +Word "has" +Word "explicit" +Throws "throws" +Comma "," +Word "no" +Word "__throws_" +Word "var" +Word "should" +Word "appear" +Function "function" +Word "apply_explicit" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_explicit_param_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_explicit" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_explicit_param_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_explicit" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Explicit" +Throws "throws" +Word "never" +Word "on" +Word "param" +Minus "-" +Word "pure" +Word "callbacks" +Word "only" +Function "function" +Word "apply_pure_only" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "never" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_pure_only_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_pure_only" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_own_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_own_throws.snap new file mode 100644 index 0000000000..6dd1d5a739 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_own_throws.snap @@ -0,0 +1,144 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "HOF" +Word "with" +Word "its" +Word "own" +Throws "throws" +Plus "+" +Word "callback" +Throws "throws" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "A" +Function "function" +Word "that" +Word "both" +Throws "throws" +Word "itself" +Word "AND" +Word "rethrows" +Word "callback" +Word "effects" +Dot "." +Slash "/" +Slash "/" +Word "The" +Word "effective" +Throws "throws" +Word "should" +Word "be" +Word "the" +Word "UNION" +Word "of" +Word "both" +Dot "." +Function "function" +Word "apply_guarded" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "result" +Equals "=" +Word "f" +LParen "(" +RParen ")" +If "if" +LParen "(" +Word "result" +Less "<" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Quote "\"" +Word "negative" +Word "result" +Quote "\"" +RBrace "}" +Word "result" +RBrace "}" +Slash "/" +Slash "/" +Word "Pure" +Word "callback" +Colon ":" +Function "function" +Error "'" +Word "s" +Word "own" +Throw "throw" +Word "should" +Word "still" +Word "appear" +Function "function" +Word "test_guarded_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_guarded" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Throwing" +Word "callback" +Colon ":" +Word "both" +Word "the" +Function "function" +Error "'" +Word "s" +Word "own" +Throw "throw" +Word "and" +Word "the" +Word "callback" +Error "'" +Word "s" +Throw "throw" +Function "function" +Word "test_guarded_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_guarded" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +IntegerLiteral "99" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_rethrows.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_rethrows.snap new file mode 100644 index 0000000000..b0d0ceed85 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_rethrows.snap @@ -0,0 +1,304 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "HOF" +Word "rethrows" +Colon ":" +Word "effect" +Word "propagation" +Word "through" +Word "higher-order" +Word "functions" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Basic" +Word "rethrows" +Minus "-" +Word "pure" +Word "callback" +Function "function" +Word "run_pure" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_run_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "run_pure" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Basic" +Word "rethrows" +Minus "-" +Word "throwing" +Word "callback" +Function "function" +Word "run_throwing" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_run_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "run_throwing" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "boom" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Rethrows" +Word "with" +Word "multiple" +Word "callback" +Word "params" +Function "function" +Word "run_two" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Comma "," +Word "g" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +Plus "+" +Word "g" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_two_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "run_two" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "1" +RBrace "}" +Comma "," +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "2" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_two_one_throws" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "run_two" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "1" +RBrace "}" +Comma "," +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "boom" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_two_both_throw" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "run_two" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "a" +Quote "\"" +RBrace "}" +Comma "," +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Rethrows" +Word "with" +Word "parametric" +Word "callback" +Function "function" +Word "map_it" +LParen "(" +Word "x" +Colon ":" +Word "int" +Comma "," +Word "f" +Colon ":" +LParen "(" +Word "int" +RParen ")" +Arrow "->" +Word "string" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "f" +LParen "(" +Word "x" +RParen ")" +RBrace "}" +Function "function" +Word "test_map_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "map_it" +LParen "(" +IntegerLiteral "1" +Comma "," +LParen "(" +Word "n" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_map_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "map_it" +LParen "(" +IntegerLiteral "1" +Comma "," +LParen "(" +Word "n" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Throw "throw" +Quote "\"" +Word "bad" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap index 73a4493d97..326e0dc631 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap @@ -207,3 +207,102 @@ IntegerLiteral "42" RBrace "}" RParen ")" RBrace "}" +Slash "/" +Slash "/" +Word "Function" +Word "with" +Word "omitted" +Throws "throws" +Word "that" +Word "still" +Throws "throws" +Word "from" +Word "its" +Word "body" +Dot "." +Function "function" +Word "helper_with_body_throw" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "helper" +Word "boom" +Quote "\"" +RBrace "}" +Slash "/" +Slash "/" +Word "Wrapper" +Word "should" +Word "keep" +Word "the" +Word "helper" +Error "'" +Word "s" +Word "concrete" +Word "body" +Throw "throw" +Word "alongside" +Word "callback" +Word "effects" +Dot "." +Function "function" +Word "apply_with_helper" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "helper_with_body_throw" +LParen "(" +RParen ")" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_apply_with_helper_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_with_helper" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_apply_with_helper_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_with_helper" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__mixed_params.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__mixed_params.snap new file mode 100644 index 0000000000..6c3b7d08c2 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__mixed_params.snap @@ -0,0 +1,184 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Mixed" +Colon ":" +Function "function" +Word "with" +Word "both" +Function "function" +Word "and" +Word "non-function" +Word "params" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Only" +Word "the" +Function "function" +Word "param" +Word "should" +Word "get" +Word "an" +Word "effect" +Word "var" +Function "function" +Word "apply_with_arg" +LParen "(" +Word "x" +Colon ":" +Word "int" +Comma "," +Word "f" +Colon ":" +LParen "(" +Word "int" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +Word "x" +RParen ")" +RBrace "}" +Function "function" +Word "test_mixed_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_with_arg" +LParen "(" +IntegerLiteral "5" +Comma "," +LParen "(" +Word "n" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "n" +Star "*" +IntegerLiteral "2" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_mixed_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_with_arg" +LParen "(" +IntegerLiteral "5" +Comma "," +LParen "(" +Word "n" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +If "if" +LParen "(" +Word "n" +Less "<" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Quote "\"" +Word "negative" +Quote "\"" +RBrace "}" +Word "n" +Star "*" +IntegerLiteral "2" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Multiple" +Word "non-function" +Word "params" +Word "with" +Word "one" +Word "callback" +Function "function" +Word "apply_with_many_args" +LParen "(" +Word "a" +Colon ":" +Word "int" +Comma "," +Word "b" +Colon ":" +Word "string" +Comma "," +Word "f" +Colon ":" +LParen "(" +Word "int" +Comma "," +Word "string" +RParen ")" +Arrow "->" +Word "string" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "f" +LParen "(" +Word "a" +Comma "," +Word "b" +RParen ")" +RBrace "}" +Function "function" +Word "test_many_args_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "apply_with_many_args" +LParen "(" +IntegerLiteral "1" +Comma "," +Quote "\"" +Word "hello" +Quote "\"" +Comma "," +LParen "(" +Word "n" +Colon ":" +Word "int" +Comma "," +Word "s" +Colon ":" +Word "string" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "s" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__returned_closures.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__returned_closures.snap new file mode 100644 index 0000000000..374c9cb45c --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__returned_closures.snap @@ -0,0 +1,120 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Returned" +Word "closures" +Colon ":" +Word "should" +Word "default" +Word "to" +Throws "throws" +Word "never" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Returning" +Word "a" +Word "pure" +Word "closure" +Minus "-" +Word "should" +Word "be" +Word "fine" +Function "function" +Word "make_pure" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +LBrace "{" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "Returning" +Word "a" +Word "closure" +Word "with" +Word "explicit" +Throws "throws" +Function "function" +Word "make_thrower" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +RParen ")" +LBrace "{" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "Using" +Word "a" +Word "returned" +Word "closure" +Function "function" +Word "test_use_pure" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +Word "make_pure" +LParen "(" +RParen ")" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_use_thrower" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +Word "make_thrower" +LParen "(" +RParen ")" +Word "f" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__array_map_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__array_map_throws.snap new file mode 100644 index 0000000000..53b8a0c345 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__array_map_throws.snap @@ -0,0 +1,136 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_map_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int[]" + WORD "int" + L_BRACKET "[" + R_BRACKET "]" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "items" + COLON ":" + TYPE_EXPR "int[]" + WORD "int" + L_BRACKET "[" + R_BRACKET "]" + EQUALS "=" + ARRAY_LITERAL "[1, 2, 3]" + L_BRACKET "[" + INTEGER_LITERAL "1" + COMMA "," + INTEGER_LITERAL "2" + COMMA "," + INTEGER_LITERAL "3" + R_BRACKET "]" + CALL_EXPR + PATH_EXPR "items.map" + WORD "items" + DOT "." + WORD "map" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER "x" + WORD "x" + R_PAREN ")" + ARROW "->" + BLOCK_EXPR + L_BRACE "{" + BINARY_EXPR "x * 2" + WORD "x" + STAR "*" + INTEGER_LITERAL "2" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_map_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int[]" + WORD "int" + L_BRACKET "[" + R_BRACKET "]" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "items" + COLON ":" + TYPE_EXPR "int[]" + WORD "int" + L_BRACKET "[" + R_BRACKET "]" + EQUALS "=" + ARRAY_LITERAL "[1, 2, 3]" + L_BRACKET "[" + INTEGER_LITERAL "1" + COMMA "," + INTEGER_LITERAL "2" + COMMA "," + INTEGER_LITERAL "3" + R_BRACKET "]" + CALL_EXPR + PATH_EXPR "items.map" + WORD "items" + DOT "." + WORD "map" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER "x" + WORD "x" + R_PAREN ")" + ARROW "->" + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 2" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "2" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "found two" + QUOTE """ + WORD "found" + WORD "two" + QUOTE """ + R_BRACE "}" + BINARY_EXPR "x * 2" + WORD "x" + STAR "*" + INTEGER_LITERAL "2" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__catch_absorbs_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__catch_absorbs_throws.snap new file mode 100644 index 0000000000..ddc109d949 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__catch_absorbs_throws.snap @@ -0,0 +1,183 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "may_fail" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 0" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "zero" + QUOTE """ + WORD "zero" + QUOTE """ + R_BRACE "}" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_no_catch" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "may_fail" + CALL_ARGS "(1)" + L_PAREN "(" + INTEGER_LITERAL "1" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_catch_all" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CATCH_EXPR + CALL_EXPR + WORD "may_fail" + CALL_ARGS "(1)" + L_PAREN "(" + INTEGER_LITERAL "1" + R_PAREN ")" + CATCH_CLAUSE + KW_CATCH "catch" + L_PAREN "(" + CATCH_PATTERN "e" + WORD "e" + R_PAREN ")" + L_BRACE "{" + CATCH_ARM + CATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + STRING_LITERAL "caught" + QUOTE """ + WORD "caught" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_catch_string" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CATCH_EXPR + CALL_EXPR + WORD "may_fail" + CALL_ARGS "(1)" + L_PAREN "(" + INTEGER_LITERAL "1" + R_PAREN ")" + CATCH_CLAUSE + KW_CATCH "catch" + L_PAREN "(" + CATCH_PATTERN "e" + WORD "e" + R_PAREN ")" + L_BRACE "{" + CATCH_ARM + CATCH_PATTERN + WORD "_" + COLON ":" + TYPE_EXPR "string" + WORD "string" + FAT_ARROW "=>" + STRING_LITERAL "caught string" + QUOTE """ + WORD "caught" + WORD "string" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_catch_rethrow" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CATCH_EXPR + CALL_EXPR + WORD "may_fail" + CALL_ARGS "(1)" + L_PAREN "(" + INTEGER_LITERAL "1" + R_PAREN ")" + CATCH_CLAUSE + KW_CATCH "catch" + L_PAREN "(" + CATCH_PATTERN "e" + WORD "e" + R_PAREN ")" + L_BRACE "{" + CATCH_ARM + CATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__chained_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__chained_hof.snap new file mode 100644 index 0000000000..4dce51bd92 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__chained_hof.snap @@ -0,0 +1,127 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_inner" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_outer" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "g" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_inner" + CALL_ARGS "(g)" + L_PAREN "(" + WORD "g" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_chained_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_outer" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_chained_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_outer" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "deep" + QUOTE """ + WORD "deep" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__explicit_param_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__explicit_param_throws.snap new file mode 100644 index 0000000000..36cb161028 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__explicit_param_throws.snap @@ -0,0 +1,163 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_explicit" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_explicit_param_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_explicit" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_explicit_param_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_explicit" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_pure_only" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "never" + WORD "never" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_pure_only_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_pure_only" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_own_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_own_throws.snap new file mode 100644 index 0000000000..bc49464f32 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_own_throws.snap @@ -0,0 +1,121 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_guarded" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "result" + EQUALS "=" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "result < 0" + WORD "result" + LESS "<" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "negative result" + QUOTE """ + WORD "negative" + WORD "result" + QUOTE """ + R_BRACE "}" + WORD "result" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_guarded_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_guarded" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_guarded_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_guarded" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 99" + KW_THROW "throw" + INTEGER_LITERAL "99" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_rethrows.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_rethrows.snap new file mode 100644 index 0000000000..6f5e45b250 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_rethrows.snap @@ -0,0 +1,426 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "run_pure" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_run_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_pure" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "run_throwing" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_run_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_throwing" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "boom" + QUOTE """ + WORD "boom" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "run_two" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + COMMA "," + PARAMETER + WORD "g" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + BINARY_EXPR + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + PLUS "+" + CALL_EXPR + WORD "g" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_two_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_two" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 1 }" + L_BRACE "{" + INTEGER_LITERAL "1" + R_BRACE "}" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 2 }" + L_BRACE "{" + INTEGER_LITERAL "2" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_two_one_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_two" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 1 }" + L_BRACE "{" + INTEGER_LITERAL "1" + R_BRACE "}" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "boom" + QUOTE """ + WORD "boom" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_two_both_throw" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_two" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "a" + QUOTE """ + WORD "a" + QUOTE """ + R_BRACE "}" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "map_it" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + COMMA "," + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "(x)" + L_PAREN "(" + WORD "x" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_map_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "map_it" + CALL_ARGS + L_PAREN "(" + INTEGER_LITERAL "1" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "n" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_map_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "map_it" + CALL_ARGS + L_PAREN "(" + INTEGER_LITERAL "1" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "n" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "bad" + QUOTE """ + WORD "bad" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap index 1c7e5a3e33..ecfdd7c114 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap @@ -241,6 +241,120 @@ SOURCE_FILE R_BRACE "}" R_PAREN ")" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "helper_with_body_throw" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "helper boom" + QUOTE """ + WORD "helper" + WORD "boom" + QUOTE """ + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_with_helper" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "helper_with_body_throw" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_with_helper_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_with_helper" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_with_helper_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_with_helper" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" === ERRORS === None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__mixed_params.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__mixed_params.snap new file mode 100644 index 0000000000..8106b64998 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__mixed_params.snap @@ -0,0 +1,238 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_with_arg" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + COMMA "," + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "(x)" + L_PAREN "(" + WORD "x" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_mixed_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_with_arg" + CALL_ARGS + L_PAREN "(" + INTEGER_LITERAL "5" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "n" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + BINARY_EXPR "n * 2" + WORD "n" + STAR "*" + INTEGER_LITERAL "2" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_mixed_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_with_arg" + CALL_ARGS + L_PAREN "(" + INTEGER_LITERAL "5" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "n" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "n < 0" + WORD "n" + LESS "<" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "negative" + QUOTE """ + WORD "negative" + QUOTE """ + R_BRACE "}" + BINARY_EXPR "n * 2" + WORD "n" + STAR "*" + INTEGER_LITERAL "2" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_with_many_args" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "a" + COLON ":" + TYPE_EXPR "int" + WORD "int" + COMMA "," + PARAMETER + WORD "b" + COLON ":" + TYPE_EXPR "string" + WORD "string" + COMMA "," + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + COMMA "," + FUNCTION_TYPE_PARAM + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "(a, b)" + L_PAREN "(" + WORD "a" + COMMA "," + WORD "b" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_many_args_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_with_many_args" + CALL_ARGS + L_PAREN "(" + INTEGER_LITERAL "1" + COMMA "," + STRING_LITERAL "hello" + QUOTE """ + WORD "hello" + QUOTE """ + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "n" + COLON ":" + TYPE_EXPR "int" + WORD "int" + COMMA "," + PARAMETER + WORD "s" + COLON ":" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR "{ s }" + L_BRACE "{" + WORD "s" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__returned_closures.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__returned_closures.snap new file mode 100644 index 0000000000..777f1272b9 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__returned_closures.snap @@ -0,0 +1,136 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_thrower" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_use_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + CALL_EXPR + WORD "make_pure" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_use_thrower" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + CALL_EXPR + WORD "make_thrower" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index 6884cbbc6a..b2819f0858 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -2,10 +2,43 @@ source: crates/baml_tests/src/generated_tests.rs --- === HIR2 === +function user.test_map_pure() -> int[] [expr] { + { let items: int[] = [1, 2, 3] } items.map((x) -> { { } x Mul 2 }) +} +function user.test_map_throwing() -> int[] [expr] { + { let items: int[] = [1, 2, 3] } items.map((x) -> { { if (x Eq 2) { throw "found two" } } x Mul 2 }) +} type user.ThrowingFn = () -> int function user.takes_throwing(f: () -> int) -> int [expr] { { } f() } +function user.may_fail(x: int) -> string [expr] { + { if (x Eq 0) { throw "zero" } } "ok" +} +function user.test_catch_all() -> string [expr] { + { } may_fail(1) catch (e) { _ => "caught" } +} +function user.test_catch_rethrow() -> string [expr] { + { } may_fail(1) catch (e) { _ => throw 42 } +} +function user.test_catch_string() -> string [expr] { + { } may_fail(1) catch (e) { _: string => "caught string" } +} +function user.test_no_catch() -> string [expr] { + { } may_fail(1) +} +function user.apply_inner(f: () -> int) -> int [expr] { + { } f() +} +function user.apply_outer(g: () -> int) -> int [expr] { + { } apply_inner(g) +} +function user.test_chained_pure() -> int [expr] { + { } apply_outer(() -> int { { } 42 }) +} +function user.test_chained_throwing() -> int [expr] { + { } apply_outer(() -> int { { throw "deep" } }) +} class user.MixedHandler { safe: () -> int risky: () -> int @@ -22,6 +55,21 @@ function user.make_pure_handler() -> user.PureHandler [expr] { function user.make_throwing_handler() -> user.ThrowingHandler [expr] { { } user.ThrowingHandler { run: () -> null { { throw "error" } } } } +function user.apply_explicit(f: () -> int) -> int [expr] { + { } f() +} +function user.apply_pure_only(f: () -> int) -> int [expr] { + { } f() +} +function user.test_explicit_param_pure() -> int [expr] { + { } apply_explicit(() -> int { { } 42 }) +} +function user.test_explicit_param_throwing() -> int [expr] { + { } apply_explicit(() -> int { { throw "ok" } }) +} +function user.test_pure_only_pure() -> int [expr] { + { } apply_pure_only(() -> int { { } 42 }) +} function user.always_ok(x: int) -> string [expr] { { } "always ok" } @@ -39,6 +87,48 @@ type user.Mapper = (int) -> string type user.PureCallback = () -> int type user.ThrowingCallback = () -> int type user.Wrapper = (() -> int) -> int +function user.apply_guarded(f: () -> int) -> int [expr] { + { let result = f(); if (result Lt 0) { throw "negative result" } } result +} +function user.test_guarded_pure() -> int [expr] { + { } apply_guarded(() -> int { { } 42 }) +} +function user.test_guarded_throwing() -> int [expr] { + { } apply_guarded(() -> int { { throw 99 } }) +} +function user.map_it(x: int, f: (int) -> string) -> string [expr] { + { } f(x) +} +function user.run_pure(f: () -> int) -> int [expr] { + { } f() +} +function user.run_throwing(f: () -> int) -> int [expr] { + { } f() +} +function user.run_two(f: () -> int, g: () -> int) -> int [expr] { + { } f() Add g() +} +function user.test_map_pure() -> string [expr] { + { } map_it(1, (n: int) -> string { { } "ok" }) +} +function user.test_map_throwing() -> string [expr] { + { } map_it(1, (n: int) -> string { { throw "bad" } }) +} +function user.test_run_pure() -> int [expr] { + { } run_pure(() -> int { { } 42 }) +} +function user.test_run_throwing() -> int [expr] { + { } run_throwing(() -> int { { throw "boom" } }) +} +function user.test_two_both_throw() -> int [expr] { + { } run_two(() -> int { { throw "a" } }, () -> int { { throw 42 } }) +} +function user.test_two_one_throws() -> int [expr] { + { } run_two(() -> int { { } 1 }, () -> int { { throw "boom" } }) +} +function user.test_two_pure() -> int [expr] { + { } run_two(() -> int { { } 1 }, () -> int { { } 2 }) +} function user.apply(f: () -> int) -> int [expr] { { } f() } @@ -48,6 +138,12 @@ function user.apply_and_throw(f: () -> int) -> int [expr] { function user.apply_throwing(f: () -> int) -> int [expr] { { } f() } +function user.apply_with_helper(f: () -> int) -> int [expr] { + { helper_with_body_throw() } f() +} +function user.helper_with_body_throw() -> int [expr] { + { throw "helper boom" } +} function user.test_apply_and_throw_pure() -> int [expr] { { } apply_and_throw(() -> int { { } 42 }) } @@ -60,6 +156,12 @@ function user.test_apply_pure() -> int [expr] { function user.test_apply_throwing() -> int [expr] { { } apply(() -> int { { throw "boom" } }) } +function user.test_apply_with_helper_pure() -> int [expr] { + { } apply_with_helper(() -> int { { } 42 }) +} +function user.test_apply_with_helper_throwing() -> int [expr] { + { } apply_with_helper(() -> int { { throw 42 } }) +} function user.test_explicit_throws_match() -> int [expr] { { let f = () -> int throws string { { throw "boom" } } } f() } @@ -87,6 +189,21 @@ function user.test_throwing_lambda() -> int [expr] { function user.test_throws_never_but_throws() -> int [expr] { { let f = () -> int throws never { { throw "boom" } } } f() } +function user.apply_with_arg(x: int, f: (int) -> int) -> int [expr] { + { } f(x) +} +function user.apply_with_many_args(a: int, b: string, f: (int, string) -> string) -> string [expr] { + { } f(a, b) +} +function user.test_many_args_pure() -> string [expr] { + { } apply_with_many_args(1, "hello", (n: int, s: string) -> string { { } s }) +} +function user.test_mixed_pure() -> int [expr] { + { } apply_with_arg(5, (n: int) -> int { { } n Mul 2 }) +} +function user.test_mixed_throwing() -> int [expr] { + { } apply_with_arg(5, (n: int) -> int { { if (n Lt 0) { throw "negative" } } n Mul 2 }) +} function user.test_nested_both_throw() -> int [expr] { { let outer = () -> int { { let inner = () -> int { { throw 42 } }; inner(); throw "outer" } } } outer() } @@ -96,3 +213,15 @@ function user.test_nested_inner_throws() -> int [expr] { function user.test_nested_outer_throws() -> int [expr] { { let outer = () -> int { { let inner = () -> int { { } 42 }; throw "outer boom" } } } outer() } +function user.make_pure() -> () -> int [expr] { + { } +} +function user.make_thrower() -> () -> int [expr] { + { } +} +function user.test_use_pure() -> int [expr] { + { let f = make_pure() } f() +} +function user.test_use_thrower() -> int [expr] { + { let f = make_thrower() } f() +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index 31f87a3c1f..a5969d347e 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -2,47 +2,860 @@ source: crates/baml_tests/src/generated_tests.rs --- === MIR2 === +fn user.test_map_pure() -> int[] { + // Locals: + let _0: int[] // _0 // return + let _1: int[] // items + let _2: int[] + let _3: (int) -> int + + bb0: { + _1 = [const 1_i64, const 2_i64, const 3_i64]; + _2 = copy _1; + _3 = make_closure lambda[0](); + _0 = call const fn baml.Array.map(copy _2, copy _3) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(x: null) -> null { + // Locals: + let _0: null // _0 // return + let _1: null // x // param + + bb0: { + _0 = copy _1 * const 2_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_map_throwing() -> int[] { + // Locals: + let _0: int[] // _0 // return + let _1: int[] // items + let _2: int[] + let _3: (int) -> int + + bb0: { + _1 = [const 1_i64, const 2_i64, const 3_i64]; + _2 = copy _1; + _3 = make_closure lambda[0](); + _0 = call const fn baml.Array.map(copy _2, copy _3) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(x: null) -> null { + // Locals: + let _0: null // _0 // return + let _1: null // x // param + let _2: bool + + bb0: { + _2 = copy _1 == const 2_i64; + branch copy _2 -> [bb4, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _0 = copy _1 * const 2_i64; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw const "found two"; + } +} + fn user.takes_throwing(f: () -> int) -> int { // Locals: let _0: int // _0 // return let _1: () -> int // f // param bb0: { - _0 = call copy _1() -> [bb1]; + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.may_fail(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: bool + + bb0: { + _2 = copy _1 == const 0_i64; + branch copy _2 -> [bb4, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _0 = const "ok"; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw const "zero"; + } +} + +fn user.test_catch_all() -> string { + // Locals: + let _0: string // _0 // return + let _1: unknown // e + + bb0: { + _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb2]; + } + + bb1: { + goto -> bb5; + } + + bb2: { + throw_if_panic copy _1 -> bb3; + } + + bb3: { + goto -> bb4; + } + + bb4: { + _0 = const "caught"; + goto -> bb5; + } + + bb5: { + goto -> bb6; + } + + bb6: { + return; + } +} + +fn user.test_catch_rethrow() -> string { + // Locals: + let _0: string // _0 // return + let _1: unknown // e + + bb0: { + _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb4]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw_if_panic copy _1 -> bb5; + } + + bb5: { + goto -> bb6; + } + + bb6: { + throw const 42_i64; + } +} + +fn user.test_catch_string() -> string { + // Locals: + let _0: string // _0 // return + let _1: unknown // e + let _2: bool + + bb0: { + _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb2]; + } + + bb1: { + goto -> bb5; + } + + bb2: { + _2 = is_type(copy _1, String { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + branch copy _2 -> [bb4, bb3]; + } + + bb3: { + throw copy _1; + } + + bb4: { + _0 = const "caught string"; + goto -> bb5; + } + + bb5: { + goto -> bb6; + } + + bb6: { + return; + } +} + +fn user.test_no_catch() -> string { + // Locals: + let _0: string // _0 // return + + bb0: { + _0 = call const fn user.may_fail(const 1_i64) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_inner(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_outer(g: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // g // param + + bb0: { + _0 = call const fn user.apply_inner(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.test_chained_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_outer(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_chained_throwing() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_outer(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "deep"; + } +} + +fn user.make_pure_handler() -> PureHandler { + // Locals: + let _0: PureHandler // _0 // return + let _1: () -> null + + bb0: { + _1 = make_closure lambda[0](); + _0 = PureHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.make_throwing_handler() -> ThrowingHandler { + // Locals: + let _0: ThrowingHandler // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = ThrowingHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + +fn user.apply_explicit(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_pure_only(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.test_explicit_param_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_explicit(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_explicit_param_throwing() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_explicit(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "ok"; + } +} + +fn user.test_pure_only_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_pure_only(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.always_ok(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + + bb0: { + _0 = const "always ok"; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.caller() -> string { + // Locals: + let _0: string // _0 // return + + bb0: { + _0 = call const fn user.may_fail(const 1_i64) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.may_fail(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: bool + + bb0: { + _2 = copy _1 == const 0_i64; + branch copy _2 -> [bb4, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _0 = const "ok"; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw const "zero"; + } +} + +fn user.safe_caller() -> string { + // Locals: + let _0: string // _0 // return + let _1: unknown // e + + bb0: { + _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb2]; + } + + bb1: { + goto -> bb5; + } + + bb2: { + throw_if_panic copy _1 -> bb3; + } + + bb3: { + goto -> bb4; + } + + bb4: { + _0 = const "caught"; + goto -> bb5; + } + + bb5: { + goto -> bb6; + } + + bb6: { + return; + } +} + +fn user.apply_guarded(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + let _2: int // result + let _3: bool + let _4: int + + bb0: { + _2 = call copy _1() -> [bb1]; + } + + bb1: { + _4 = copy _2; + _3 = copy _4 < const 0_i64; + branch copy _3 -> [bb5, bb2]; + } + + bb2: { + goto -> bb3; + } + + bb3: { + _0 = copy _2; + goto -> bb4; + } + + bb4: { + return; + } + + bb5: { + throw const "negative result"; + } +} + +fn user.test_guarded_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_guarded(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_guarded_throwing() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_guarded(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 99_i64; + } +} + +fn user.map_it(x: int, f: (int) -> string) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: (int) -> string // f // param + + bb0: { + _0 = call copy _2(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.run_pure(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.run_throwing(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.run_two(f: () -> int, g: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + let _2: () -> int // g // param + let _3: int + let _4: int + + bb0: { + _3 = call copy _1() -> [bb1]; } bb1: { - goto -> bb2; + _4 = call copy _2() -> [bb2]; } bb2: { + _0 = copy _3 + copy _4; + goto -> bb3; + } + + bb3: { return; } } -fn user.make_pure_handler() -> PureHandler { +fn user.test_map_pure() -> string { // Locals: - let _0: PureHandler // _0 // return - let _1: () -> null + let _0: string // _0 // return + let _1: (int) -> "ok" bb0: { _1 = make_closure lambda[0](); - _0 = PureHandler { copy _1 }; - goto -> bb1; + _0 = call const fn user.map_it(const 1_i64, copy _1) -> [bb1]; } bb1: { + goto -> bb2; + } + + bb2: { return; } } // lambda[0] -fn .() -> null { +fn .(n: int) -> null { // Locals: let _0: null // _0 // return + let _1: int // n // param bb0: { - _0 = const null; + _0 = const "ok"; goto -> bb1; } @@ -51,39 +864,62 @@ fn .() -> null { } } -fn user.make_throwing_handler() -> ThrowingHandler { +fn user.test_map_throwing() -> string { // Locals: - let _0: ThrowingHandler // _0 // return - let _1: () -> void + let _0: string // _0 // return + let _1: (int) -> void bb0: { _1 = make_closure lambda[0](); - _0 = ThrowingHandler { copy _1 }; - goto -> bb1; + _0 = call const fn user.map_it(const 1_i64, copy _1) -> [bb1]; } bb1: { + goto -> bb2; + } + + bb2: { return; } } // lambda[0] -fn .() -> null { +fn .(n: int) -> null { // Locals: let _0: null // _0 // return + let _1: int // n // param bb0: { - throw const "error"; + throw const "bad"; } } -fn user.always_ok(x: int) -> string { +fn user.test_run_pure() -> int { // Locals: - let _0: string // _0 // return - let _1: int // x // param + let _0: int // _0 // return + let _1: () -> 42 bb0: { - _0 = const "always ok"; + _1 = make_closure lambda[0](); + _0 = call const fn user.run_pure(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; goto -> bb1; } @@ -92,12 +928,14 @@ fn user.always_ok(x: int) -> string { } } -fn user.caller() -> string { +fn user.test_run_throwing() -> int { // Locals: - let _0: string // _0 // return + let _0: int // _0 // return + let _1: () -> void bb0: { - _0 = call const fn user.may_fail(const 1_i64) -> [bb1]; + _1 = make_closure lambda[0](); + _0 = call const fn user.run_throwing(copy _1) -> [bb1]; } bb1: { @@ -109,15 +947,26 @@ fn user.caller() -> string { } } -fn user.may_fail(x: int) -> string { +// lambda[0] +fn .() -> null { // Locals: - let _0: string // _0 // return - let _1: int // x // param - let _2: bool + let _0: null // _0 // return bb0: { - _2 = copy _1 == const 0_i64; - branch copy _2 -> [bb4, bb1]; + throw const "boom"; + } +} + +fn user.test_two_both_throw() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = make_closure lambda[1](); + _0 = call const fn user.run_two(copy _1, copy _2) -> [bb1]; } bb1: { @@ -125,50 +974,123 @@ fn user.may_fail(x: int) -> string { } bb2: { - _0 = const "ok"; - goto -> bb3; + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "a"; } +} - bb3: { +// lambda[1] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + +fn user.test_two_one_throws() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 1 + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = make_closure lambda[1](); + _0 = call const fn user.run_two(copy _1, copy _2) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { return; } +} - bb4: { - throw const "zero"; +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 1_i64; + goto -> bb1; + } + + bb1: { + return; } } -fn user.safe_caller() -> string { +// lambda[1] +fn .() -> null { // Locals: - let _0: string // _0 // return - let _1: unknown // e + let _0: null // _0 // return bb0: { - _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb2]; + throw const "boom"; + } +} + +fn user.test_two_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 1 + let _2: () -> 2 + + bb0: { + _1 = make_closure lambda[0](); + _2 = make_closure lambda[1](); + _0 = call const fn user.run_two(copy _1, copy _2) -> [bb1]; } bb1: { - goto -> bb5; + goto -> bb2; } bb2: { - throw_if_panic copy _1 -> bb3; + return; } +} - bb3: { - goto -> bb4; +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 1_i64; + goto -> bb1; } - bb4: { - _0 = const "caught"; - goto -> bb5; + bb1: { + return; } +} - bb5: { - goto -> bb6; +// lambda[1] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 2_i64; + goto -> bb1; } - bb6: { + bb1: { return; } } @@ -210,30 +1132,115 @@ fn user.apply_and_throw(f: () -> int) -> int { } bb2: { - goto -> bb3; + goto -> bb3; + } + + bb3: { + _0 = copy _2; + goto -> bb4; + } + + bb4: { + return; + } + + bb5: { + throw const "negative result"; + } +} + +fn user.apply_throwing(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_with_helper(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + let _2: int + + bb0: { + _2 = call const fn user.helper_with_body_throw() -> [bb1]; + } + + bb1: { + _0 = call copy _1() -> [bb2]; + } + + bb2: { + goto -> bb3; + } + + bb3: { + return; + } +} + +fn user.helper_with_body_throw() -> int { + // Locals: + let _0: int // _0 // return + + bb0: { + throw const "helper boom"; + } +} + +fn user.test_apply_and_throw_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> 42 + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_and_throw(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; } +} - bb3: { - _0 = copy _2; - goto -> bb4; - } +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return - bb4: { - return; + bb0: { + _0 = const 42_i64; + goto -> bb1; } - bb5: { - throw const "negative result"; + bb1: { + return; } } -fn user.apply_throwing(f: () -> int) -> int { +fn user.test_apply_explicit_throws() -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> void bb0: { - _0 = call copy _1() -> [bb1]; + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_throwing(copy _1) -> [bb1]; } bb1: { @@ -245,14 +1252,24 @@ fn user.apply_throwing(f: () -> int) -> int { } } -fn user.test_apply_and_throw_pure() -> int { +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "boom"; + } +} + +fn user.test_apply_pure() -> int { // Locals: let _0: int // _0 // return let _1: () -> 42 bb0: { _1 = make_closure lambda[0](); - _0 = call const fn user.apply_and_throw(copy _1) -> [bb1]; + _0 = call const fn user.apply(copy _1) -> [bb1]; } bb1: { @@ -265,7 +1282,7 @@ fn user.test_apply_and_throw_pure() -> int { } // lambda[0] -fn .() -> null { +fn .() -> null { // Locals: let _0: null // _0 // return @@ -279,14 +1296,14 @@ fn .() -> null { } } -fn user.test_apply_explicit_throws() -> int { +fn user.test_apply_throwing() -> int { // Locals: let _0: int // _0 // return let _1: () -> void bb0: { _1 = make_closure lambda[0](); - _0 = call const fn user.apply_throwing(copy _1) -> [bb1]; + _0 = call const fn user.apply(copy _1) -> [bb1]; } bb1: { @@ -299,7 +1316,7 @@ fn user.test_apply_explicit_throws() -> int { } // lambda[0] -fn .() -> null { +fn .() -> null { // Locals: let _0: null // _0 // return @@ -308,14 +1325,14 @@ fn .() -> null { } } -fn user.test_apply_pure() -> int { +fn user.test_apply_with_helper_pure() -> int { // Locals: let _0: int // _0 // return let _1: () -> 42 bb0: { _1 = make_closure lambda[0](); - _0 = call const fn user.apply(copy _1) -> [bb1]; + _0 = call const fn user.apply_with_helper(copy _1) -> [bb1]; } bb1: { @@ -328,7 +1345,7 @@ fn user.test_apply_pure() -> int { } // lambda[0] -fn .() -> null { +fn .() -> null { // Locals: let _0: null // _0 // return @@ -342,14 +1359,14 @@ fn .() -> null { } } -fn user.test_apply_throwing() -> int { +fn user.test_apply_with_helper_throwing() -> int { // Locals: let _0: int // _0 // return let _1: () -> void bb0: { _1 = make_closure lambda[0](); - _0 = call const fn user.apply(copy _1) -> [bb1]; + _0 = call const fn user.apply_with_helper(copy _1) -> [bb1]; } bb1: { @@ -362,12 +1379,12 @@ fn user.test_apply_throwing() -> int { } // lambda[0] -fn .() -> null { +fn .() -> null { // Locals: let _0: null // _0 // return bb0: { - throw const "boom"; + throw const 42_i64; } } @@ -709,6 +1726,165 @@ fn .() -> null { } } +fn user.apply_with_arg(x: int, f: (int) -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x // param + let _2: (int) -> int // f // param + + bb0: { + _0 = call copy _2(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_with_many_args(a: int, b: string, f: (int, string) -> string) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // a // param + let _2: string // b // param + let _3: (int, string) -> string // f // param + + bb0: { + _0 = call copy _3(copy _1, copy _2) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.test_many_args_pure() -> string { + // Locals: + let _0: string // _0 // return + let _1: (int, string) -> string + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_with_many_args(const 1_i64, const "hello", copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(n: int, s: string) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // n // param + let _2: string // s // param + + bb0: { + _0 = copy _2; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_mixed_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: (int) -> int + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_with_arg(const 5_i64, copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(n: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // n // param + + bb0: { + _0 = copy _1 * const 2_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_mixed_throwing() -> int { + // Locals: + let _0: int // _0 // return + let _1: (int) -> int + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_with_arg(const 5_i64, copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(n: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // n // param + let _2: bool + + bb0: { + _2 = copy _1 < const 0_i64; + branch copy _2 -> [bb4, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _0 = copy _1 * const 2_i64; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw const "negative"; + } +} + fn user.test_nested_both_throw() -> int { // Locals: let _0: int // _0 // return @@ -857,3 +2033,79 @@ fn ., 1)>() -> null { return; } } + +fn user.make_pure() -> () -> int { + // Locals: + let _0: () -> int // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.make_thrower() -> () -> int { + // Locals: + let _0: () -> int // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_use_pure() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f + let _2: () -> int + + bb0: { + _1 = call const fn user.make_pure() -> [bb1]; + } + + bb1: { + _2 = copy _1; + _0 = call copy _2() -> [bb2]; + } + + bb2: { + goto -> bb3; + } + + bb3: { + return; + } +} + +fn user.test_use_thrower() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f + let _2: () -> int + + bb0: { + _1 = call const fn user.make_thrower() -> [bb1]; + } + + bb1: { + _2 = copy _1; + _0 = call copy _2() -> [bb2]; + } + + bb2: { + goto -> bb3; + } + + bb3: { + return; + } +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 129e309a7d..aa9b79d7f6 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -2,13 +2,111 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === +function user.test_map_pure() -> int[] throws never { + { : int[] + let items = [1, 2, 3] : int[] + items.map((x) -> { ... }) : int[] + (x) -> { ... } : (x: int) -> int + { + x * 2 + } + } +} +lambda user.test_map_pure { +} +function user.test_map_throwing() -> int[] throws string { + { : int[] + let items = [1, 2, 3] : int[] + items.map((x) -> { ... }) : int[] + (x) -> { ... } : (x: int) -> int throws string + { + if (x == 2) + { + throw "found two" + } + x * 2 + } + } +} +lambda user.test_map_throwing { +} type user.ThrowingFn = () -> int throws string -function user.takes_throwing(f: () -> int throws string) -> int throws never { +function user.takes_throwing(f: () -> int throws string) -> int throws string { { : int f() : int } } type user.ThrowingFn$stream = unknown +function user.may_fail(x: int) -> string throws string { + { : "ok" + if (x == 0 : bool) : void + { : never + throw "zero" : "zero" + } + "ok" : "ok" + } +} +function user.test_no_catch() -> string throws string { + { : string + may_fail(1) : string + } +} +function user.test_catch_all() -> string throws never { + { : string | "caught" + catch (may_fail(1) : string) : unknown + catch (e) + _ => + "caught" : "caught" + } +} +function user.test_catch_string() -> string throws never { + { : string | "caught string" + catch (may_fail(1) : string) : unknown + catch (e) + _: string => + "caught string" : "caught string" + } +} +function user.test_catch_rethrow() -> string throws int { + { : string + catch (may_fail(1) : string) : unknown + catch (e) + _ => + throw 42 : never + } +} +function user.apply_inner(f: () -> int throws __throws_f) -> int throws __throws_f { + { : int + f() : int + } +} +function user.apply_outer(g: () -> int throws __throws_g) -> int throws __throws_g { + { : int + apply_inner(g) : int + } +} +function user.test_chained_pure() -> int throws never { + { : int + apply_outer(() -> int { ... }) : int + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.test_chained_pure { +} +function user.test_chained_throwing() -> int throws string { + { : int + apply_outer(() -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "deep" + } + } +} +lambda user.test_chained_throwing { +} class user.PureHandler { run: () -> null } @@ -43,6 +141,49 @@ class user.MixedHandler$stream { safe: unknown risky: unknown } +function user.apply_explicit(f: () -> int throws string) -> int throws string { + { : int + f() : int + } +} +function user.test_explicit_param_pure() -> int throws string { + { : int + apply_explicit(() -> int { ... }) : int + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.test_explicit_param_pure { +} +function user.test_explicit_param_throwing() -> int throws string { + { : int + apply_explicit(() -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "ok" + } + } +} +lambda user.test_explicit_param_throwing { +} +function user.apply_pure_only(f: () -> int) -> int throws never { + { : int + f() : int + } +} +function user.test_pure_only_pure() -> int throws never { + { : int + apply_pure_only(() -> int { ... }) : int + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.test_pure_only_pure { +} function user.may_fail(x: int) -> string throws string { { : "ok" if (x == 0 : bool) : void @@ -57,7 +198,7 @@ function user.always_ok(x: int) -> string throws never { "always ok" : "always ok" } } -function user.caller() -> string throws never { +function user.caller() -> string throws string { { : string may_fail(1) : string } @@ -80,7 +221,154 @@ type user.PureCallback$stream = unknown type user.ExplicitPure$stream = unknown type user.Mapper$stream = unknown type user.Wrapper$stream = unknown -function user.apply(f: () -> int) -> int throws never { +function user.apply_guarded(f: () -> int throws __throws_f) -> int throws string | __throws_f { + { : int + let result = f() : int + if (result < 0 : bool) : void + { : never + throw "negative result" : "negative result" + } + result : int + } +} +function user.test_guarded_pure() -> int throws string { + { : int + apply_guarded(() -> int { ... }) : int + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.test_guarded_pure { +} +function user.test_guarded_throwing() -> int throws int | string { + { : int + apply_guarded(() -> int { ... }) : int + () -> int { ... } : () -> never throws int + { + throw 99 + } + } +} +lambda user.test_guarded_throwing { +} +function user.run_pure(f: () -> int throws __throws_f) -> int throws __throws_f { + { : int + f() : int + } +} +function user.test_run_pure() -> int throws never { + { : int + run_pure(() -> int { ... }) : int + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.test_run_pure { +} +function user.run_throwing(f: () -> int throws __throws_f) -> int throws __throws_f { + { : int + f() : int + } +} +function user.test_run_throwing() -> int throws string { + { : int + run_throwing(() -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "boom" + } + } +} +lambda user.test_run_throwing { +} +function user.run_two(f: () -> int throws __throws_f, g: () -> int throws __throws_g) -> int throws __throws_f | __throws_g { + { : int + f() + g() : int + } +} +function user.test_two_pure() -> int throws never { + { : int + run_two(() -> int { ... }, () -> int { ... }) : int + () -> int { ... } : () -> 1 + { + 1 + } + () -> int { ... } : () -> 2 + { + 2 + } + } +} +lambda user.test_two_pure { +} +lambda user.test_two_pure { +} +function user.test_two_one_throws() -> int throws string { + { : int + run_two(() -> int { ... }, () -> int { ... }) : int + () -> int { ... } : () -> 1 + { + 1 + } + () -> int { ... } : () -> never throws string + { + throw "boom" + } + } +} +lambda user.test_two_one_throws { +} +lambda user.test_two_one_throws { +} +function user.test_two_both_throw() -> int throws int | string { + { : int + run_two(() -> int { ... }, () -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "a" + } + () -> int { ... } : () -> never throws int + { + throw 42 + } + } +} +lambda user.test_two_both_throw { +} +lambda user.test_two_both_throw { +} +function user.map_it(x: int, f: (int) -> string throws __throws_f) -> string throws __throws_f { + { : string + f(x) : string + } +} +function user.test_map_pure() -> string throws never { + { : string + map_it(1, (n: int) -> string { ... }) : string + (n: int) -> string { ... } : (n: int) -> "ok" + { + "ok" + } + } +} +lambda user.test_map_pure { +} +function user.test_map_throwing() -> string throws string { + { : string + map_it(1, (n: int) -> string { ... }) : string + (n: int) -> string { ... } : (n: int) -> never throws string + { + throw "bad" + } + } +} +lambda user.test_map_throwing { +} +function user.apply(f: () -> int throws __throws_f) -> int throws __throws_f { { : int f() : int } @@ -96,7 +384,7 @@ function user.test_apply_pure() -> int throws never { } lambda user.test_apply_pure { } -function user.test_apply_throwing() -> int throws never { +function user.test_apply_throwing() -> int throws string { { : int apply(() -> int { ... }) : int () -> int { ... } : () -> never throws string @@ -107,12 +395,12 @@ function user.test_apply_throwing() -> int throws never { } lambda user.test_apply_throwing { } -function user.apply_throwing(f: () -> int throws string) -> int throws never { +function user.apply_throwing(f: () -> int throws string) -> int throws string { { : int f() : int } } -function user.test_apply_explicit_throws() -> int throws never { +function user.test_apply_explicit_throws() -> int throws string { { : int apply_throwing(() -> int { ... }) : int () -> int { ... } : () -> never throws string @@ -123,7 +411,7 @@ function user.test_apply_explicit_throws() -> int throws never { } lambda user.test_apply_explicit_throws { } -function user.apply_and_throw(f: () -> int) -> int throws never { +function user.apply_and_throw(f: () -> int throws __throws_f) -> int throws string | __throws_f { { : int let result = f() : int if (result < 0 : bool) : void @@ -133,7 +421,7 @@ function user.apply_and_throw(f: () -> int) -> int throws never { result : int } } -function user.test_apply_and_throw_pure() -> int throws never { +function user.test_apply_and_throw_pure() -> int throws string { { : int apply_and_throw(() -> int { ... }) : int () -> int { ... } : () -> 42 @@ -144,7 +432,40 @@ function user.test_apply_and_throw_pure() -> int throws never { } lambda user.test_apply_and_throw_pure { } -function user.test_explicit_throws_match() -> int throws never { +function user.helper_with_body_throw() -> int throws string { + { : never + throw "helper boom" : "helper boom" + } +} +function user.apply_with_helper(f: () -> int throws __throws_f) -> int throws string | __throws_f { + { : int + helper_with_body_throw() : int + f() : int + } +} +function user.test_apply_with_helper_pure() -> int throws string { + { : int + apply_with_helper(() -> int { ... }) : int + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.test_apply_with_helper_pure { +} +function user.test_apply_with_helper_throwing() -> int throws int | string { + { : int + apply_with_helper(() -> int { ... }) : int + () -> int { ... } : () -> never throws int + { + throw 42 + } + } +} +lambda user.test_apply_with_helper_throwing { +} +function user.test_explicit_throws_match() -> int throws string { { : never let f = : () -> never throws string () -> int throws string { ... } : () -> never throws string @@ -168,7 +489,7 @@ function user.test_explicit_throws_never_pure() -> int throws never { } lambda user.test_explicit_throws_never_pure { } -function user.test_explicit_throws_wider_than_body() -> int throws never { +function user.test_explicit_throws_wider_than_body() -> int throws string { { : 42 let f = : () -> 42 throws string () -> int throws string { ... } : () -> 42 throws string @@ -193,7 +514,7 @@ function user.test_pure_lambda() -> int throws never { } lambda user.test_pure_lambda { } -function user.test_throwing_lambda() -> int throws never { +function user.test_throwing_lambda() -> int throws string { { : never let f = : () -> never throws string () -> int { ... } : () -> never throws string @@ -205,7 +526,7 @@ function user.test_throwing_lambda() -> int throws never { } lambda user.test_throwing_lambda { } -function user.test_throwing_int() -> string throws never { +function user.test_throwing_int() -> string throws int { { : never let f = : () -> never throws int () -> string { ... } : () -> never throws int @@ -217,7 +538,7 @@ function user.test_throwing_int() -> string throws never { } lambda user.test_throwing_int { } -function user.test_conditional_throw(x: int) -> int throws never { +function user.test_conditional_throw(x: int) -> int throws string { { : int let f = : (n: int) -> int throws string (n: int) -> int { ... } : (n: int) -> int throws string @@ -233,7 +554,7 @@ function user.test_conditional_throw(x: int) -> int throws never { } lambda user.test_conditional_throw { } -function user.test_multi_throw_types(x: int) -> string throws never { +function user.test_multi_throw_types(x: int) -> string throws int | string { { : "ok" let f = : (n: int) -> "ok" throws int | string (n: int) -> string { ... } : (n: int) -> "ok" throws int | string @@ -258,7 +579,54 @@ function user.test_throws_never_but_throws() -> int throws never { lambda user.test_throws_never_but_throws { !! 213..219: throws contract violation: `never` is missing string } -function user.test_nested_inner_throws() -> int throws never { +function user.apply_with_arg(x: int, f: (int) -> int throws __throws_f) -> int throws __throws_f { + { : int + f(x) : int + } +} +function user.test_mixed_pure() -> int throws never { + { : int + apply_with_arg(5, (n: int) -> int { ... }) : int + (n: int) -> int { ... } : (n: int) -> int + { + n * 2 + } + } +} +lambda user.test_mixed_pure { +} +function user.test_mixed_throwing() -> int throws string { + { : int + apply_with_arg(5, (n: int) -> int { ... }) : int + (n: int) -> int { ... } : (n: int) -> int throws string + { + if (n < 0) + { + throw "negative" + } + n * 2 + } + } +} +lambda user.test_mixed_throwing { +} +function user.apply_with_many_args(a: int, b: string, f: (int, string) -> string throws __throws_f) -> string throws __throws_f { + { : string + f(a, b) : string + } +} +function user.test_many_args_pure() -> string throws never { + { : string + apply_with_many_args(1, "hello", (n: int, s: string) -> string { ... }) : string + (n: int, s: string) -> string { ... } : (n: int, s: string) -> string + { + s + } + } +} +lambda user.test_many_args_pure { +} +function user.test_nested_inner_throws() -> int throws string { { : never let outer = : () -> never throws string () -> int { ... } : () -> never throws string @@ -277,7 +645,7 @@ lambda user.test_nested_inner_throws { } lambda user.test_nested_inner_throws { } -function user.test_nested_outer_throws() -> int throws never { +function user.test_nested_outer_throws() -> int throws string { { : never let outer = : () -> never throws string () -> int { ... } : () -> never throws string @@ -296,10 +664,10 @@ lambda user.test_nested_outer_throws { } lambda user.test_nested_outer_throws { } -function user.test_nested_both_throw() -> int throws never { +function user.test_nested_both_throw() -> int throws int | string { { : never - let outer = : () -> never throws int | unknown - () -> int { ... } : () -> never throws int | unknown + let outer = : () -> never throws int | string + () -> int { ... } : () -> never throws int | string { let inner = ... () -> int { ... } @@ -318,3 +686,25 @@ lambda user.test_nested_both_throw { } lambda user.test_nested_both_throw { } +function user.make_pure() -> () -> int throws never { + { : () -> int + } + !! 142..165: missing return: expected `() -> int` +} +function user.make_thrower() -> () -> int throws string throws never { + { : () -> int throws string + } + !! 263..297: missing return: expected `() -> int throws string` +} +function user.test_use_pure() -> int throws never { + { : int + let f = make_pure() : () -> int + f() : int + } +} +function user.test_use_thrower() -> int throws string { + { : int + let f = make_thrower() : () -> int throws string + f() : int + } +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index 589b199aab..ca1fc4928c 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -2,6 +2,54 @@ source: crates/baml_tests/src/generated_tests.rs --- === COMPILER2 DIAGNOSTICS === + [validation] Error: Duplicate function `may_fail` + ╭─[ catch_absorbs_throws.baml:3:10 ] + │ + 3 │ function may_fail(x: int) -> string throws string { + │ ────┬─── + │ ╰───── first defined as function here + │ + ├─[ fn_decl_throws.baml:4:10 ] + │ + 4 │ function may_fail(x: int) -> string throws string { + │ ────┬─── + │ ╰───── duplicate function definition + │ + │ Note: Error code: E0011 +───╯ + + [validation] Error: Duplicate function `test_map_pure` + ╭─[ array_map_throws.baml:4:10 ] + │ + 4 │ function test_map_pure() -> int[] { + │ ──────┬────── + │ ╰──────── first defined as function here + │ + ├─[ hof_rethrows.baml:37:10 ] + │ + 37 │ function test_map_pure() -> string { + │ ──────┬────── + │ ╰──────── duplicate function definition + │ + │ Note: Error code: E0011 +────╯ + + [validation] Error: Duplicate function `test_map_throwing` + ╭─[ array_map_throws.baml:10:10 ] + │ + 10 │ function test_map_throwing() -> int[] { + │ ────────┬──────── + │ ╰────────── first defined as function here + │ + ├─[ hof_rethrows.baml:41:10 ] + │ + 41 │ function test_map_throwing() -> string { + │ ────────┬──────── + │ ╰────────── duplicate function definition + │ + │ Note: Error code: E0011 +────╯ + [type] Warning: extraneous throws declaration: string ╭─[ lambda_throws_explicit.baml:17:27 ] │ @@ -41,3 +89,27 @@ source: crates/baml_tests/src/generated_tests.rs │ │ Note: Error code: E0001 ────╯ + + [type] Error: missing return: expected `() -> int` + ╭─[ returned_closures.baml:4:36 ] + │ + 4 │ ╭─▶ function make_pure() -> (() -> int) { + ┆ ┆ + 6 │ ├─▶ } + │ │ + │ ╰─────── missing return: expected `() -> int` + │ + │ Note: Error code: E0001 +───╯ + + [type] Error: missing return: expected `() -> int throws string` + ╭─[ returned_closures.baml:9:53 ] + │ + 9 │ ╭─▶ function make_thrower() -> (() -> int throws string) { + ┆ ┆ + 11 │ ├─▶ } + │ │ + │ ╰─────── missing return: expected `() -> int throws string` + │ + │ Note: Error code: E0001 +────╯ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 01c949daa0..caba8aa39b 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -31,18 +31,94 @@ function user.apply_and_throw(f: () -> int) -> int { throw } +function user.apply_explicit(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.apply_guarded(f: () -> int) -> int { + load_var f + call_indirect + store_var result + load_var result + load_const 0 + cmp_op < + pop_jump_if_false L0 + jump L1 + + L0: + load_var result + return + + L1: + load_const "negative result" + throw +} + +function user.apply_inner(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.apply_outer(g: () -> int) -> int { + load_var g + call user.apply_inner + return +} + +function user.apply_pure_only(f: () -> int) -> int { + load_var f + call_indirect + return +} + function user.apply_throwing(f: () -> int) -> int { load_var f call_indirect return } +function user.apply_with_arg(x: int, f: (int) -> int) -> int { + load_var x + load_var f + call_indirect + return +} + +function user.apply_with_helper(f: () -> int) -> int { + call user.helper_with_body_throw + pop 1 + load_var f + call_indirect + return +} + +function user.apply_with_many_args(a: int, b: string, f: (int, string) -> string) -> string { + load_var a + load_var b + load_var f + call_indirect + return +} + function user.caller() -> string { load_const 1 call user.may_fail return } +function user.helper_with_body_throw() -> int { + load_const "helper boom" + throw +} + +function user.make_pure() -> () -> int { + load_const null + return +} + function user.make_pure_handler() -> PureHandler { alloc_instance PureHandler copy 0 @@ -51,6 +127,11 @@ function user.make_pure_handler() -> PureHandler { return } +function user.make_thrower() -> () -> int { + load_const null + return +} + function user.make_throwing_handler() -> ThrowingHandler { alloc_instance ThrowingHandler copy 0 @@ -59,6 +140,13 @@ function user.make_throwing_handler() -> ThrowingHandler { return } +function user.map_it(x: int, f: (int) -> string) -> string { + load_var x + load_var f + call_indirect + return +} + function user.may_fail(x: int) -> string { load_var x load_const 0 @@ -75,6 +163,31 @@ function user.may_fail(x: int) -> string { throw } +function user.run_pure(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.run_throwing(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.run_two(f: () -> int, g: () -> int) -> int { + load_var f + call_indirect + store_var _3 + load_var g + call_indirect + store_var _4 + load_var _3 + load_var _4 + bin_op + + return +} + function user.safe_caller() -> string { load_const 1 call user.may_fail @@ -117,6 +230,77 @@ function user.test_apply_throwing() -> int { return } +function user.test_apply_with_helper_pure() -> int { + make_closure ., 0 + call user.apply_with_helper + return +} + +function user.test_apply_with_helper_throwing() -> int { + make_closure ., 0 + call user.apply_with_helper + return +} + +function user.test_catch_all() -> string { + load_const 1 + call user.may_fail + jump L0 + load_var e + throw_if_panic + load_const "caught" + + L0: + return +} + +function user.test_catch_rethrow() -> string { + load_const 1 + call user.may_fail + jump L0 + load_var e + throw_if_panic + load_const 42 + throw + + L0: + return +} + +function user.test_catch_string() -> string { + load_const 1 + call user.may_fail + jump L2 + load_var e + type_tag + load_const 1 + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var e + throw + + L1: + load_const "caught string" + + L2: + return +} + +function user.test_chained_pure() -> int { + make_closure ., 0 + call user.apply_outer + return +} + +function user.test_chained_throwing() -> int { + make_closure ., 0 + call user.apply_outer + return +} + function user.test_conditional_throw(x: int) -> int { load_var x make_closure ., 0 @@ -124,6 +308,18 @@ function user.test_conditional_throw(x: int) -> int { return } +function user.test_explicit_param_pure() -> int { + make_closure ., 0 + call user.apply_explicit + return +} + +function user.test_explicit_param_throwing() -> int { + make_closure ., 0 + call user.apply_explicit + return +} + function user.test_explicit_throws_match() -> int { make_closure ., 0 call_indirect @@ -142,6 +338,54 @@ function user.test_explicit_throws_wider_than_body() -> int { return } +function user.test_guarded_pure() -> int { + make_closure ., 0 + call user.apply_guarded + return +} + +function user.test_guarded_throwing() -> int { + make_closure ., 0 + call user.apply_guarded + return +} + +function user.test_many_args_pure() -> string { + load_const 1 + load_const "hello" + make_closure ., 0 + call user.apply_with_many_args + return +} + +function user.test_map_pure() -> string { + load_const 1 + make_closure ., 0 + call user.map_it + return +} + +function user.test_map_throwing() -> string { + load_const 1 + make_closure ., 0 + call user.map_it + return +} + +function user.test_mixed_pure() -> int { + load_const 5 + make_closure ., 0 + call user.apply_with_arg + return +} + +function user.test_mixed_throwing() -> int { + load_const 5 + make_closure ., 0 + call user.apply_with_arg + return +} + function user.test_multi_throw_types(x: int) -> string { load_var x make_closure ., 0 @@ -167,12 +411,36 @@ function user.test_nested_outer_throws() -> int { return } +function user.test_no_catch() -> string { + load_const 1 + call user.may_fail + return +} + function user.test_pure_lambda() -> int { make_closure ., 0 call_indirect return } +function user.test_pure_only_pure() -> int { + make_closure ., 0 + call user.apply_pure_only + return +} + +function user.test_run_pure() -> int { + make_closure ., 0 + call user.run_pure + return +} + +function user.test_run_throwing() -> int { + make_closure ., 0 + call user.run_throwing + return +} + function user.test_throwing_int() -> string { make_closure ., 0 call_indirect @@ -190,3 +458,36 @@ function user.test_throws_never_but_throws() -> int { call_indirect return } + +function user.test_two_both_throw() -> int { + make_closure ., 0 + make_closure ., 0 + call user.run_two + return +} + +function user.test_two_one_throws() -> int { + make_closure ., 0 + make_closure ., 0 + call user.run_two + return +} + +function user.test_two_pure() -> int { + make_closure ., 0 + make_closure ., 0 + call user.run_two + return +} + +function user.test_use_pure() -> int { + call user.make_pure + call_indirect + return +} + +function user.test_use_thrower() -> int { + call user.make_thrower + call_indirect + return +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__array_map_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__array_map_throws.snap new file mode 100644 index 0000000000..799895a7f7 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__array_map_throws.snap @@ -0,0 +1,29 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === Array.map with throwing callback === + +// Pure map callback +function test_map_pure() -> int[] { + let items: int[] = [1, 2, 3] + items + .map( + (x) -> { + x * 2 + }, + ) +} + +// Throwing map callback +function test_map_throwing() -> int[] { + let items: int[] = [1, 2, 3] + items + .map( + (x) -> { + if (x == 2) { + throw "found two" + } + x * 2 + }, + ) +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__catch_absorbs_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__catch_absorbs_throws.snap new file mode 100644 index 0000000000..3dd5257123 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__catch_absorbs_throws.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Expected token/node of kind LLM_FUNCTION_BODY or EXPR_FUNCTION_BODY, but found THROWS_CLAUSE at catch_absorbs_throws.baml:3:36 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__chained_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__chained_hof.snap new file mode 100644 index 0000000000..be9c6fdb29 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__chained_hof.snap @@ -0,0 +1,30 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === Chained HOF: throws through multiple layers === + +function apply_inner(f: () -> int) -> int { + f() +} + +function apply_outer(g: () -> int) -> int { + apply_inner(g) +} + +// Pure through two layers +function test_chained_pure() -> int { + apply_outer( + () -> int { + 42 + }, + ) +} + +// Throwing through two layers +function test_chained_throwing() -> int { + apply_outer( + () -> int { + throw "deep" + }, + ) +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__explicit_param_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__explicit_param_throws.snap new file mode 100644 index 0000000000..161aea4e88 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__explicit_param_throws.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Unexpected additional element at explicit_param_throws.baml:4:37 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_own_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_own_throws.snap new file mode 100644 index 0000000000..097d4e94ed --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_own_throws.snap @@ -0,0 +1,32 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === HOF with its own throws + callback throws === +// A function that both throws itself AND rethrows callback effects. +// The effective throws should be the UNION of both. + +function apply_guarded(f: () -> int) -> int { + let result = f() + if (result < 0) { + throw "negative result" + } + result +} + +// Pure callback: function's own throw should still appear +function test_guarded_pure() -> int { + apply_guarded( + () -> int { + 42 + }, + ) +} + +// Throwing callback: both the function's own throw and the callback's throw +function test_guarded_throwing() -> int { + apply_guarded( + () -> int { + throw 99 + }, + ) +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_rethrows.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_rethrows.snap new file mode 100644 index 0000000000..6b07e7e834 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_rethrows.snap @@ -0,0 +1,89 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === HOF rethrows: effect propagation through higher-order functions === + +// Basic rethrows - pure callback +function run_pure(f: () -> int) -> int { + f() +} +function test_run_pure() -> int { + run_pure( + () -> int { + 42 + }, + ) +} + +// Basic rethrows - throwing callback +function run_throwing(f: () -> int) -> int { + f() +} +function test_run_throwing() -> int { + run_throwing( + () -> int { + throw "boom" + }, + ) +} + +// Rethrows with multiple callback params +function run_two(f: () -> int, g: () -> int) -> int { + f() + g() +} + +function test_two_pure() -> int { + run_two( + () -> int { + 1 + }, + () -> int { + 2 + }, + ) +} + +function test_two_one_throws() -> int { + run_two( + () -> int { + 1 + }, + () -> int { + throw "boom" + }, + ) +} + +function test_two_both_throw() -> int { + run_two( + () -> int { + throw "a" + }, + () -> int { + throw 42 + }, + ) +} + +// Rethrows with parametric callback +function map_it(x: int, f: (int) -> string) -> string { + f(x) +} + +function test_map_pure() -> string { + map_it( + 1, + (n: int) -> string { + "ok" + }, + ) +} + +function test_map_throwing() -> string { + map_it( + 1, + (n: int) -> string { + throw "bad" + }, + ) +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__mixed_params.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__mixed_params.snap new file mode 100644 index 0000000000..75d80b39e9 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__mixed_params.snap @@ -0,0 +1,45 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +// === Mixed: function with both function and non-function params === + +// Only the function param should get an effect var +function apply_with_arg(x: int, f: (int) -> int) -> int { + f(x) +} + +function test_mixed_pure() -> int { + apply_with_arg( + 5, + (n: int) -> int { + n * 2 + }, + ) +} + +function test_mixed_throwing() -> int { + apply_with_arg( + 5, + (n: int) -> int { + if (n < 0) { + throw "negative" + } + n * 2 + }, + ) +} + +// Multiple non-function params with one callback +function apply_with_many_args(a: int, b: string, f: (int, string) -> string) -> string { + f(a, b) +} + +function test_many_args_pure() -> string { + apply_with_many_args( + 1, + "hello", + (n: int, s: string) -> string { + s + }, + ) +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__returned_closures.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__returned_closures.snap new file mode 100644 index 0000000000..a1d4202ca1 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__returned_closures.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Unexpected additional element at returned_closures.baml:9:38 diff --git a/baml_language/crates/baml_tests/snapshots/lambda_advanced/baml_tests__lambda_advanced__04_tir.snap b/baml_language/crates/baml_tests/snapshots/lambda_advanced/baml_tests__lambda_advanced__04_tir.snap index 92396012f0..6a3a4b39ec 100644 --- a/baml_language/crates/baml_tests/snapshots/lambda_advanced/baml_tests__lambda_advanced__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/lambda_advanced/baml_tests__lambda_advanced__04_tir.snap @@ -5,8 +5,8 @@ source: crates/baml_tests/src/generated_tests.rs function user.test_chained_map() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] - let step1 = items.map((x) -> { ... }) : int[] - step1.map((x) -> { ... }) : int[] + let step1 = items.map((x) -> { ... }) : int[] + step1.map((x) -> { ... }) : int[] (x) -> { ... } : (x: int) -> int { x + 1 @@ -20,7 +20,7 @@ lambda user.test_chained_map { function user.test_map_string_to_int() -> int[] throws never { { : int[] let words = ["hello", "world"] : string[] - words.map((w: string) -> int { ... }) : int[] + words.map((w: string) -> int { ... }) : int[] (w: string) -> int { ... } : (w: string) -> int { w.length() @@ -142,7 +142,7 @@ function user.test_map_with_capture() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] let factor = 10 : 10 -> int - items.map((x) -> { ... }) : int[] + items.map((x) -> { ... }) : int[] (x) -> { ... } : (x: int) -> int { x * factor @@ -154,7 +154,7 @@ lambda user.test_map_with_capture { function user.test_nested_map() -> int[][] throws never { { : int[][] let matrix = [[1, 2], [3, 4]] : int[][] - matrix.map((row: int[]) -> { ... }) : int[][] + matrix.map((row: int[]) -> { ... }) : int[][] (row: int[]) -> { ... } : (row: int[]) -> int[] { row.map((x) -> { ... }) @@ -168,7 +168,7 @@ lambda user.test_nested_map { function user.test_map_multi_statement() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] - items.map((x) -> { ... }) : int[] + items.map((x) -> { ... }) : int[] (x) -> { ... } : (x: int) -> int { let doubled = x * 2 @@ -182,7 +182,7 @@ lambda user.test_map_multi_statement { function user.test_map_entries() -> string[] throws never { { : string[] let m = map { "a": 1, "b": 2 } : map - m.map((k: string, v: int) -> string { ... }) : string[] + m.map((k: string, v: int) -> string { ... }) : string[] (k: string, v: int) -> string { ... } : (k: string, v: int) -> string { k @@ -194,7 +194,7 @@ lambda user.test_map_entries { function user.test_map_values_transform() -> int[] throws never { { : int[] let m = map { "x": 10, "y": 20 } : map - m.map_values((v) -> { ... }) : int[] + m.map_values((v) -> { ... }) : int[] (v) -> { ... } : (v: int) -> int { v * 2 @@ -206,7 +206,7 @@ lambda user.test_map_values_transform { function user.test_map_keys_transform() -> string[] throws never { { : string[] let m = map { "hello": 1 } : map - m.map_keys((k) -> { ... }) : string[] + m.map_keys((k) -> { ... }) : string[] (k) -> { ... } : (k: string) -> string { k @@ -223,7 +223,7 @@ function user.test_lambda_variable_then_pass() -> int[] throws never { { x / 2 } - items.map(halve) : int[] + items.map(halve) : int[] } } lambda user.test_lambda_variable_then_pass { @@ -232,7 +232,7 @@ function user.test_shadowing() -> string throws never { { : string let x = "999" : "999" -> string let items = [1, 2, 3] : int[] - let mapped = items.map((x) -> { ... }) : int[] + let mapped = items.map((x) -> { ... }) : int[] for y in mapped { : void let i = 0 : 0 -> int @@ -277,7 +277,7 @@ lambda user.test_optional_return { function user.test_fully_inferred_map() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] - items.map((x) -> { ... }) : int[] + items.map((x) -> { ... }) : int[] (x) -> { ... } : (x: int) -> int { x * 2 @@ -289,7 +289,7 @@ lambda user.test_fully_inferred_map { function user.test_map_type_change_inferred() -> string[] throws never { { : (string | "one" | "two" | "other")[] let nums = [1, 2, 3] : int[] - nums.map((n) -> { ... }) : (string | "one" | "two" | "other")[] + nums.map((n) -> { ... }) : (string | "one" | "two" | "other")[] (n) -> { ... } : (n: int) -> "one" | "two" | "other" { match (n) { 1 => "one", 2 => "two", _ => "other" } @@ -301,7 +301,7 @@ lambda user.test_map_type_change_inferred { function user.test_map_values_fully_inferred() -> string[] throws never { { : (string | "big" | "small")[] let m = map { "x": 10, "y": 20 } : map - m.map_values((v) -> { ... }) : (string | "big" | "small")[] + m.map_values((v) -> { ... }) : (string | "big" | "small")[] (v) -> { ... } : (v: int) -> "big" | "small" { if (v > 15) @@ -320,7 +320,7 @@ lambda user.test_map_values_fully_inferred { function user.test_literal_return_inference() -> int[] throws never { { : (int | 0)[] let items = [1, 2, 3] : int[] - items.map((x) -> { ... }) : (int | 0)[] + items.map((x) -> { ... }) : (int | 0)[] (x) -> { ... } : (x: int) -> 0 { 0 @@ -332,8 +332,8 @@ lambda user.test_literal_return_inference { function user.test_compose_inferred() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] - let step1 = items.map((x) -> { ... }) : int[] - step1.map((x) -> { ... }) : int[] + let step1 = items.map((x) -> { ... }) : int[] + step1.map((x) -> { ... }) : int[] (x) -> { ... } : (x: int) -> int { x + 1 @@ -478,7 +478,7 @@ function user.test_capture_in_map_inferred() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] let factor = 10 : 10 -> int - items.map((x) -> { ... }) : int[] + items.map((x) -> { ... }) : int[] (x) -> { ... } : (x: int) -> int { x * factor @@ -536,7 +536,7 @@ lambda user.test_capture_lambda_inferred { function user.test_map_optional_inferred() -> int[] throws never { { : (int | 0 | 1)[] let items = [1, null, 3] : int?[] - items.map((x) -> { ... }) : (int | 0 | 1)[] + items.map((x) -> { ... }) : (int | 0 | 1)[] (x) -> { ... } : (x: int?) -> 0 | 1 { match (x) { null => 0, _ => 1 } @@ -549,7 +549,7 @@ function user.test_shadowing_inferred() -> int[] throws never { { : int[] let x = "not a number" : "not a number" -> string let items = [1, 2, 3] : int[] - items.map((x) -> { ... }) : int[] + items.map((x) -> { ... }) : int[] (x) -> { ... } : (x: int) -> int { x + 1 @@ -568,7 +568,7 @@ lambda user.test_iife_inferred { function user.test_map_entries_inferred() -> string[] throws never { { : string[] let m = map { "a": 1, "b": 2 } : map - m.map((k, v) -> { ... }) : string[] + m.map((k, v) -> { ... }) : string[] (k, v) -> { ... } : (k: string, v: int) -> string { k diff --git a/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_tir.snap b/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_tir.snap index 76a19013f0..4a0396a386 100644 --- a/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_tir.snap @@ -41,7 +41,7 @@ lambda user.test_zero_param { function user.test_map_inferred() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] - items.map((x) -> { ... }) : int[] + items.map((x) -> { ... }) : int[] (x) -> { ... } : (x: int) -> int { x * 2 @@ -50,7 +50,7 @@ function user.test_map_inferred() -> int[] throws never { } lambda user.test_map_inferred { } -function user.test_throws() -> int throws never { +function user.test_throws() -> int throws string { { : int let risky = : (x: int) -> int throws string (x: int) -> int throws string { ... } : (x: int) -> int throws string diff --git a/baml_language/crates/baml_tests/snapshots/lambda_errors/baml_tests__lambda_errors__04_tir.snap b/baml_language/crates/baml_tests/snapshots/lambda_errors/baml_tests__lambda_errors__04_tir.snap index 6f3b37600f..e7b834523e 100644 --- a/baml_language/crates/baml_tests/snapshots/lambda_errors/baml_tests__lambda_errors__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/lambda_errors/baml_tests__lambda_errors__04_tir.snap @@ -31,7 +31,7 @@ lambda user.test_no_annotation_no_context { function user.test_arity_mismatch() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] - items.map((x, y) -> { ... }) : int[] + items.map((x, y) -> { ... }) : int[] (x, y) -> { ... } : (x: int, y: unknown) -> int { x @@ -44,7 +44,7 @@ lambda user.test_arity_mismatch { function user.test_param_type_mismatch() -> int[] throws never { { : (int | string)[] let items = [1, 2, 3] : int[] - items.map((x: string) -> { ... }) : (int | string)[] + items.map((x: string) -> { ... }) : (int | string)[] (x: string) -> { ... } : (x: string) -> string { x diff --git a/baml_language/crates/baml_tests/snapshots/namespaces_type_resolution/baml_tests__namespaces_type_resolution__04_tir.snap b/baml_language/crates/baml_tests/snapshots/namespaces_type_resolution/baml_tests__namespaces_type_resolution__04_tir.snap index 9c45ee294b..0ff71e5aa4 100644 --- a/baml_language/crates/baml_tests/snapshots/namespaces_type_resolution/baml_tests__namespaces_type_resolution__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/namespaces_type_resolution/baml_tests__namespaces_type_resolution__04_tir.snap @@ -43,7 +43,7 @@ function user.RootAliasParam(cfg: user.ConfigAlias) -> string throws never { cfg.key : string } } -function user.RootThrowClass(x: int) -> string throws never { +function user.RootThrowClass(x: int) -> string throws user.Error { { : "ok" match (x : int) : "ok" 0 => @@ -52,7 +52,7 @@ function user.RootThrowClass(x: int) -> string throws never { "ok" : "ok" } } -function user.RootThrowEnum(x: int) -> string throws never { +function user.RootThrowEnum(x: int) -> string throws user.Status.Failed { { : "ok" match (x : int) : "ok" 0 => @@ -176,7 +176,7 @@ function user.auth.AuthAliasParam(t: user.auth.TokenAlias) -> string throws neve t.value : string } } -function user.auth.AuthThrowClass(x: int) -> string throws never { +function user.auth.AuthThrowClass(x: int) -> string throws user.auth.Error { { : "ok" match (x : int) : "ok" 0 => @@ -185,7 +185,7 @@ function user.auth.AuthThrowClass(x: int) -> string throws never { "ok" : "ok" } } -function user.auth.AuthThrowEnum(x: int) -> string throws never { +function user.auth.AuthThrowEnum(x: int) -> string throws user.auth.Role.Guest { { : "ok" match (x : int) : "ok" 0 => @@ -220,7 +220,7 @@ function user.auth.AuthMatchesRootStatus(s: user.Status) -> string throws never "pending" : "pending" } } -function user.auth.AuthThrowsRootError(x: int) -> string throws never { +function user.auth.AuthThrowsRootError(x: int) -> string throws user.auth.Error { { : "ok" match (x : int) : "ok" 0 => @@ -330,7 +330,7 @@ function user.llm.LlmAliasParam(cfg: user.llm.ConfigAlias) -> string throws neve cfg.model : string } } -function user.llm.LlmThrowClass(x: int) -> string throws never { +function user.llm.LlmThrowClass(x: int) -> string throws user.llm.Error { { : "ok" match (x : int) : "ok" 0 => @@ -339,7 +339,7 @@ function user.llm.LlmThrowClass(x: int) -> string throws never { "ok" : "ok" } } -function user.llm.LlmThrowEnum(x: int) -> string throws never { +function user.llm.LlmThrowEnum(x: int) -> string throws user.llm.Status.Done { { : "ok" match (x : int) : "ok" 0 => @@ -384,7 +384,7 @@ function user.llm.LlmUsesRootAlias(cfg: user.ConfigAlias) -> string throws never cfg.key : string } } -function user.llm.LlmThrowsRootError(x: int) -> string throws never { +function user.llm.LlmThrowsRootError(x: int) -> string throws user.llm.Error { { : "ok" match (x : int) : "ok" 0 => @@ -393,7 +393,7 @@ function user.llm.LlmThrowsRootError(x: int) -> string throws never { "ok" : "ok" } } -function user.llm.LlmThrowsRootEnum(x: int) -> string throws never { +function user.llm.LlmThrowsRootEnum(x: int) -> string throws user.Status.Failed { { : "ok" match (x : int) : "ok" 0 => @@ -506,7 +506,7 @@ function user.llm.openai.OpenAIEnumMatch(m: user.llm.openai.Model) -> string thr "o3" : "o3" } } -function user.llm.openai.OpenAIThrowClass(x: int) -> string throws never { +function user.llm.openai.OpenAIThrowClass(x: int) -> string throws user.llm.openai.Error { { : "ok" match (x : int) : "ok" 0 => @@ -515,7 +515,7 @@ function user.llm.openai.OpenAIThrowClass(x: int) -> string throws never { "ok" : "ok" } } -function user.llm.openai.OpenAIThrowEnum(x: int) -> string throws never { +function user.llm.openai.OpenAIThrowEnum(x: int) -> string throws user.llm.openai.Model.O3 { { : "ok" match (x : int) : "ok" 0 => diff --git a/baml_language/crates/baml_tests/snapshots/null_handling/baml_tests__null_handling__04_tir.snap b/baml_language/crates/baml_tests/snapshots/null_handling/baml_tests__null_handling__04_tir.snap index 5c1282d638..3e3e859a46 100644 --- a/baml_language/crates/baml_tests/snapshots/null_handling/baml_tests__null_handling__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/null_handling/baml_tests__null_handling__04_tir.snap @@ -72,7 +72,7 @@ function user.OptionalIndex(node: user.Node?, items: int[]?, lookup: map string?, nonNullCb: (int) -> string, name: string?) -> string throws never { +function user.OptionalCall(callback: (int) -> string? throws __throws_callback, nonNullCb: (int) -> string throws __throws_nonNullCb, name: string?) -> string throws __throws_callback | __throws_nonNullCb { { : never let c1 = callback?.(42) : string? let c2 = callback?.(42).length : ((self: baml.String) -> int)? diff --git a/baml_language/crates/baml_tests/snapshots/type_annotation/baml_tests__type_annotation__04_tir.snap b/baml_language/crates/baml_tests/snapshots/type_annotation/baml_tests__type_annotation__04_tir.snap index db9d668b57..7199953bd6 100644 --- a/baml_language/crates/baml_tests/snapshots/type_annotation/baml_tests__type_annotation__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/type_annotation/baml_tests__type_annotation__04_tir.snap @@ -8,7 +8,7 @@ type user.MetaUnion = type | string type user.MetaAlias$stream = unknown type user.MetaList$stream = unknown[] type user.MetaUnion$stream = unknown | string -function user.test_type_annotation(function_name: string) -> string throws never { +function user.test_type_annotation(function_name: string) -> string throws baml.errors.InvalidArgument { { : never let t = baml.llm.get_return_type(function_name) : type return "ok" : "ok" @@ -29,7 +29,7 @@ class user.RiskFactor$stream { class user.MetaField$stream { type: unknown } -function user.test_let_binding(function_name: string) -> string throws never { +function user.test_let_binding(function_name: string) -> string throws baml.errors.InvalidArgument { { : never let t = baml.llm.get_return_type(function_name) : type return "ok" : "ok" @@ -40,7 +40,7 @@ function user.test_param(t: type) -> string throws never { return "ok" : "ok" } } -function user.test_return_type(function_name: string) -> type throws never { +function user.test_return_type(function_name: string) -> type throws baml.errors.InvalidArgument { { : never return baml.llm.get_return_type(function_name) : type } diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs index 6958281a94..fb8b02dc2b 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs @@ -39,7 +39,9 @@ pub(crate) mod support { ScopeInference, infer_scope_types, render_scope_diagnostics, resolve_class_fields, resolve_type_alias, }, - lower_type_expr::lower_type_expr_in_ns, + lower_type_expr::{ + FnTypeLoweringContext, lower_type_expr_in_ns, lower_type_expr_with_fn_context, + }, }; use baml_project::ProjectDatabase; @@ -1034,6 +1036,7 @@ pub(crate) mod support { format!("<{}>", names.join(", ")) }; + let mut synthetic_display_vars: Vec = Vec::new(); let params: Vec = sig .params .iter() @@ -1048,7 +1051,18 @@ pub(crate) mod support { ) } else { let mut diags = Vec::new(); - lower_type_expr_in_ns(db, ptype, pkg_items, ns, gp, &mut diags) + lower_type_expr_with_fn_context( + db, + ptype, + pkg_items, + ns, + gp, + &mut diags, + &FnTypeLoweringContext::DirectParamRoot { + param_name: pname.clone(), + }, + &mut synthetic_display_vars, + ) }; format!("{}: {}", pname, ty) }) @@ -1075,6 +1089,27 @@ pub(crate) mod support { }) }; + // Compute post-inference effective throws from the function body. + // This captures HOF effect propagation that the pre-inference + // throw_sets cannot see. + let post_inference_throws: Option = { + if let Some(ref fb) = func_body_opt { + if let baml_compiler2_hir::body::FunctionBody::Expr(ref body) = **fb + { + let ty = inference.effective_throws(db, pkg_id, body); + if matches!(ty, baml_compiler2_tir::ty::Ty::Never { .. }) { + None + } else { + Some(ty.to_string()) + } + } else { + None + } + } else { + None + } + }; + let throws = if let Some(t) = &sig.throws { let mut diags = Vec::new(); let declared = @@ -1085,8 +1120,51 @@ pub(crate) mod support { } None => format!(" throws {declared}"), } - } else { + } else if !synthetic_display_vars.is_empty() { + // Function has implicit effect vars from callback params. + // Union the effect vars with the body's own concrete throws. + let mut all_throws: Vec = synthetic_display_vars + .iter() + .map(|v| v.to_string()) + .collect(); + // Add body's own throws (post_inference_throws excludes TypeVars, + // so these are the function's own concrete throws). + if let Some(ref body_throws) = post_inference_throws { + // Split the body throws string and add each component. + for component in body_throws.split(" | ") { + let trimmed = component.trim(); + if !trimmed.is_empty() + && !all_throws.contains(&trimmed.to_string()) + { + all_throws.push(trimmed.to_string()); + } + } + } + // Sort for deterministic output: effect vars first, then others. + all_throws.sort_by(|a, b| { + let a_is_effect = a.starts_with("__throws_"); + let b_is_effect = b.starts_with("__throws_"); + match (a_is_effect, b_is_effect) { + (true, false) => std::cmp::Ordering::Greater, + (false, true) => std::cmp::Ordering::Less, + _ => a.cmp(b), + } + }); + let throws_str = all_throws.join(" | "); match &inferred_throws { + Some(inferred) => { + format!(" throws {throws_str} infers {inferred}") + } + None => format!(" throws {throws_str}"), + } + } else { + // No explicit throws, no effect vars. + // Use post-inference effective throws if available, + // falling back to pre-inference transitive throws. + match post_inference_throws + .as_deref() + .or(inferred_throws.as_deref()) + { Some(inferred) => format!(" throws {inferred}"), None => " throws never".to_string(), } From 96087108c15412172f9b04acc51f80a3d52f644f Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 06:04:37 -0500 Subject: [PATCH 04/26] baml_language: add stored callback enforcement for typed rethrows Implement Phase 4 of typed rethrows for higher-order functions: - Add StoredFunctionRequiresExplicitThrows diagnostic for when a throwing function is assigned to a stored position typed throws never - Enforce throws compatibility in lambda bidirectional checking - Resolve type aliases before lambda checking to handle type alias cases - Add comprehensive test coverage for class fields, local variables, type aliases, and return positions Co-Authored-By: Claude Opus 4.5 --- .../crates/baml_compiler2_tir/src/builder.rs | 113 +++- .../baml_compiler2_tir/src/infer_context.rs | 15 + .../crates/baml_lsp2_actions/src/check.rs | 2 + .../stored_callback_enforcement.baml | 82 +++ ...01_lexer__stored_callback_enforcement.snap | 500 ++++++++++++++++++ ...2_parser__stored_callback_enforcement.snap | 467 ++++++++++++++++ ...l_tests__function_type_throws__03_hir.snap | 40 ++ ...tests__function_type_throws__04_5_mir.snap | 290 ++++++++++ ...l_tests__function_type_throws__04_tir.snap | 132 +++++ ..._function_type_throws__05_diagnostics.snap | 40 ++ ...sts__function_type_throws__06_codegen.snap | 64 +++ ...ormatter__stored_callback_enforcement.snap | 5 + 12 files changed, 1724 insertions(+), 26 deletions(-) create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__stored_callback_enforcement.snap diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index a765fe139e..1ad4f8a5bf 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -914,14 +914,8 @@ impl<'db> TypeInferenceBuilder<'db> { // let the subtype check report the error. let inferred = self.infer_expr(expr_id, body); if !self.is_subtype(&inferred, expected) { - self.context.report( - TirTypeError::TypeMismatch { - expected: expected.clone(), - got: inferred.clone(), - }, - expr_id, - Vec::new(), - ); + // Literals won't be functions, but use the helper for consistency + self.report_type_mismatch(expected, &inferred, expr_id); } inferred } @@ -1119,7 +1113,9 @@ impl<'db> TypeInferenceBuilder<'db> { } // Lambda: bidirectional checking against expected function type Expr::Lambda(func_def) => { - match expected { + // Resolve type aliases so we can decompose function types + let expected_resolved = self.resolve_alias_chain(expected); + match &expected_resolved { Ty::Function { params: expected_params, ret: expected_ret, @@ -1233,9 +1229,23 @@ impl<'db> TypeInferenceBuilder<'db> { let result = Ty::Function { params: param_tys, ret: Box::new(ret_ty), - throws: Box::new(throws_ty), + throws: Box::new(throws_ty.clone()), attr: TyAttr::default(), }; + + // Check throws compatibility: if expected is `throws never` but + // lambda body throws, report the stored function diagnostic + if matches!(expected_throws.as_ref(), Ty::Never { .. }) + && !matches!(throws_ty, Ty::Never { .. }) + { + self.context.report_simple( + TirTypeError::StoredFunctionRequiresExplicitThrows { + actual_throws: throws_ty, + }, + expr_id, + ); + } + self.record_expr_type(expr_id, result.clone()); result } @@ -1265,14 +1275,8 @@ impl<'db> TypeInferenceBuilder<'db> { self.context .report_simple(TirTypeError::VoidUsedAsValue, expr_id); } else if !self.is_subtype(&inferred, expected) { - self.context.report( - TirTypeError::TypeMismatch { - expected: expected.clone(), - got: inferred.clone(), - }, - expr_id, - Vec::new(), - ); + // Use specialized diagnostic for stored function throws mismatch + self.report_type_mismatch(expected, &inferred, expr_id); } inferred } @@ -1437,14 +1441,8 @@ impl<'db> TypeInferenceBuilder<'db> { && !matches!(value_ty, Ty::Unknown { .. } | Ty::Error { .. }) && !self.is_subtype(&value_ty, decl_ty) { - self.context.report( - TirTypeError::TypeMismatch { - expected: decl_ty.clone(), - got: value_ty.clone(), - }, - *value, - Vec::new(), - ); + // Use specialized diagnostic for stored function throws mismatch + self.report_type_mismatch(decl_ty, &value_ty, *value); } // Update the local to the assigned value's type (invalidates narrowing) if let Expr::Path(segments) = &body.exprs[*target] { @@ -4488,6 +4486,69 @@ impl<'db> TypeInferenceBuilder<'db> { crate::normalize::is_subtype_of(sub, sup, &self.aliases) } + /// Report a type mismatch, with special handling for stored function throws. + /// + /// When assigning a throwing function to a stored position (class field, type alias, + /// return type, local) that defaults to `throws never`, emit a more specific diagnostic + /// explaining that explicit `throws` annotation is required. + fn report_type_mismatch(&self, expected: &Ty, got: &Ty, at: ExprId) { + // Resolve type aliases for comparison + let expected_resolved = self.resolve_alias_chain(expected); + let got_resolved = self.resolve_alias_chain(got); + + // Check for stored function throws mismatch + if let ( + Ty::Function { + throws: expected_throws, + .. + }, + Ty::Function { + throws: actual_throws, + .. + }, + ) = (&expected_resolved, &got_resolved) + { + // If expected is `throws never` but actual throws something else, + // emit the specific stored-function diagnostic + if matches!(expected_throws.as_ref(), Ty::Never { .. }) + && !matches!(actual_throws.as_ref(), Ty::Never { .. }) + { + self.context.report_simple( + TirTypeError::StoredFunctionRequiresExplicitThrows { + actual_throws: actual_throws.as_ref().clone(), + }, + at, + ); + return; + } + } + + // Default: report generic type mismatch + self.context.report( + TirTypeError::TypeMismatch { + expected: expected.clone(), + got: got.clone(), + }, + at, + Vec::new(), + ); + } + + /// Resolve a type alias chain to its underlying type (up to a depth limit). + fn resolve_alias_chain(&self, ty: &Ty) -> Ty { + let mut resolved = ty.clone(); + for _ in 0..64 { + match &resolved { + Ty::TypeAlias(qtn, _) => match self.aliases.get(qtn) { + Some(expanded) => resolved = expanded.clone(), + None => break, + }, + _ => break, + } + } + resolved + } + fn infer_binary_op( &mut self, op: baml_compiler2_ast::BinaryOp, diff --git a/baml_language/crates/baml_compiler2_tir/src/infer_context.rs b/baml_language/crates/baml_compiler2_tir/src/infer_context.rs index 9b4dce17cb..5de680aff2 100644 --- a/baml_language/crates/baml_compiler2_tir/src/infer_context.rs +++ b/baml_language/crates/baml_compiler2_tir/src/infer_context.rs @@ -142,6 +142,14 @@ pub enum TirTypeError { /// The full expression text (e.g. `a.name`) expr: String, }, + /// A throwing function is assigned to a stored position (class field, type alias, + /// return type, local variable) that defaults to `throws never`. + /// + /// The fix is to add an explicit `throws` annotation to the stored function type. + StoredFunctionRequiresExplicitThrows { + /// The inferred throws type of the actual function being stored. + actual_throws: Ty, + }, } impl fmt::Display for TirTypeError { @@ -310,6 +318,13 @@ impl fmt::Display for TirTypeError { "did you mean `{suggested}`? `{expr}` does not handle the case when `{base}` is null" ) } + TirTypeError::StoredFunctionRequiresExplicitThrows { actual_throws } => { + write!( + f, + "function that `throws {actual_throws}` cannot be stored in a position typed `throws never`; \ + add an explicit `throws {actual_throws}` annotation to the stored function type" + ) + } } } } diff --git a/baml_language/crates/baml_lsp2_actions/src/check.rs b/baml_language/crates/baml_lsp2_actions/src/check.rs index e58e753371..cc66159a83 100644 --- a/baml_language/crates/baml_lsp2_actions/src/check.rs +++ b/baml_language/crates/baml_lsp2_actions/src/check.rs @@ -299,6 +299,8 @@ fn tir_type_error_to_diagnostic_id( TirTypeError::SuggestNullCoalesce { .. } => DiagnosticId::InvalidOperator, TirTypeError::NullCoalesceWithNull { .. } => DiagnosticId::InvalidOperator, TirTypeError::NullableMemberAccess { .. } => DiagnosticId::TypeMismatch, + // Stored function throws mismatch + TirTypeError::StoredFunctionRequiresExplicitThrows { .. } => DiagnosticId::TypeMismatch, } } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml b/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml new file mode 100644 index 0000000000..f02b081d56 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml @@ -0,0 +1,82 @@ +// === Stored callback enforcement tests === +// These test that storing a throwing function in a position typed `throws never` produces an error. + +// --- Class field assignments --- + +class StoredPureHandler { + run: () -> null +} + +// ERROR: assigning throwing lambda to field typed throws never +function make_bad_stored_handler() -> StoredPureHandler { + StoredPureHandler { run: () -> null { throw "error" } } +} + +// OK: assigning pure lambda to field typed throws never +function make_good_stored_handler() -> StoredPureHandler { + StoredPureHandler { run: () -> null { null } } +} + +// --- Class with explicit throws on field --- + +class StoredThrowingHandler { + run: () -> null throws string +} + +// OK: assigning throwing lambda to field typed throws string +function make_stored_throwing_handler() -> StoredThrowingHandler { + StoredThrowingHandler { run: () -> null { throw "error" } } +} + +// --- Local variable assignment --- + +// ERROR: assigning throwing lambda to variable typed throws never +function test_stored_local_error() -> null { + let f: () -> null = () -> null { throw "oops" } + null +} + +// OK: assigning pure lambda to variable typed throws never +function test_stored_local_ok() -> null { + let f: () -> null = () -> null { null } + null +} + +// OK: assigning throwing lambda to variable typed throws string +function test_stored_local_with_throws() -> null { + let f: () -> null throws string = () -> null { throw "oops" } + null +} + +// --- Type alias usage --- + +type StoredPureCb = () -> int + +// ERROR: assigning throwing lambda to type alias that defaults to throws never +function test_stored_alias_error() -> null { + let cb: StoredPureCb = () -> int { throw "error" } + null +} + +// OK: assigning pure lambda +function test_stored_alias_ok() -> null { + let cb: StoredPureCb = () -> int { 42 } + null +} + +// --- Return position --- + +// ERROR: returning throwing closure when return type defaults to throws never +function make_stored_closure_bad() -> (() -> int) { + return () -> int { throw "oops" } +} + +// OK: returning pure closure +function make_stored_closure_good() -> (() -> int) { + return () -> int { 42 } +} + +// OK: returning throwing closure when return type has explicit throws +function make_stored_closure_with_throws() -> (() -> int throws string) { + return () -> int { throw "oops" } +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap new file mode 100644 index 0000000000..5249184d9d --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap @@ -0,0 +1,500 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Stored" +Word "callback" +Word "enforcement" +Word "tests" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "These" +Test "test" +Word "that" +Word "storing" +Word "a" +Word "throwing" +Function "function" +In "in" +Word "a" +Word "position" +Word "typed" +Error "`" +Throws "throws" +Word "never" +Error "`" +Word "produces" +Word "an" +Word "error" +Dot "." +Slash "/" +Slash "/" +MinusMinus "--" +Minus "-" +Word "Class" +Word "field" +Word "assignments" +MinusMinus "--" +Minus "-" +Class "class" +Word "StoredPureHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +RBrace "}" +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "assigning" +Word "throwing" +Word "lambda" +Word "to" +Word "field" +Word "typed" +Throws "throws" +Word "never" +Function "function" +Word "make_bad_stored_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "StoredPureHandler" +LBrace "{" +Word "StoredPureHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "OK" +Colon ":" +Word "assigning" +Word "pure" +Word "lambda" +Word "to" +Word "field" +Word "typed" +Throws "throws" +Word "never" +Function "function" +Word "make_good_stored_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "StoredPureHandler" +LBrace "{" +Word "StoredPureHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "null" +RBrace "}" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +MinusMinus "--" +Minus "-" +Word "Class" +Word "with" +Word "explicit" +Throws "throws" +Word "on" +Word "field" +MinusMinus "--" +Minus "-" +Class "class" +Word "StoredThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "string" +RBrace "}" +Slash "/" +Slash "/" +Word "OK" +Colon ":" +Word "assigning" +Word "throwing" +Word "lambda" +Word "to" +Word "field" +Word "typed" +Throws "throws" +Word "string" +Function "function" +Word "make_stored_throwing_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "StoredThrowingHandler" +LBrace "{" +Word "StoredThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +MinusMinus "--" +Minus "-" +Word "Local" +Word "variable" +Word "assignment" +MinusMinus "--" +Minus "-" +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "assigning" +Word "throwing" +Word "lambda" +Word "to" +Word "variable" +Word "typed" +Throws "throws" +Word "never" +Function "function" +Word "test_stored_local_error" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "oops" +Quote "\"" +RBrace "}" +Word "null" +RBrace "}" +Slash "/" +Slash "/" +Word "OK" +Colon ":" +Word "assigning" +Word "pure" +Word "lambda" +Word "to" +Word "variable" +Word "typed" +Throws "throws" +Word "never" +Function "function" +Word "test_stored_local_ok" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "null" +RBrace "}" +Word "null" +RBrace "}" +Slash "/" +Slash "/" +Word "OK" +Colon ":" +Word "assigning" +Word "throwing" +Word "lambda" +Word "to" +Word "variable" +Word "typed" +Throws "throws" +Word "string" +Function "function" +Word "test_stored_local_with_throws" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "string" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "oops" +Quote "\"" +RBrace "}" +Word "null" +RBrace "}" +Slash "/" +Slash "/" +MinusMinus "--" +Minus "-" +Word "Type" +Word "alias" +Word "usage" +MinusMinus "--" +Minus "-" +Word "type" +Word "StoredPureCb" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "assigning" +Word "throwing" +Word "lambda" +Word "to" +Word "type" +Word "alias" +Word "that" +Word "defaults" +Word "to" +Throws "throws" +Word "never" +Function "function" +Word "test_stored_alias_error" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "cb" +Colon ":" +Word "StoredPureCb" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "error" +Quote "\"" +RBrace "}" +Word "null" +RBrace "}" +Slash "/" +Slash "/" +Word "OK" +Colon ":" +Word "assigning" +Word "pure" +Word "lambda" +Function "function" +Word "test_stored_alias_ok" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "cb" +Colon ":" +Word "StoredPureCb" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +Word "null" +RBrace "}" +Slash "/" +Slash "/" +MinusMinus "--" +Minus "-" +Word "Return" +Word "position" +MinusMinus "--" +Minus "-" +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "returning" +Word "throwing" +Word "closure" +Word "when" +Return "return" +Word "type" +Word "defaults" +Word "to" +Throws "throws" +Word "never" +Function "function" +Word "make_stored_closure_bad" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +LBrace "{" +Return "return" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "oops" +Quote "\"" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "OK" +Colon ":" +Word "returning" +Word "pure" +Word "closure" +Function "function" +Word "make_stored_closure_good" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +LBrace "{" +Return "return" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "OK" +Colon ":" +Word "returning" +Word "throwing" +Word "closure" +Word "when" +Return "return" +Word "type" +Word "has" +Word "explicit" +Throws "throws" +Function "function" +Word "make_stored_closure_with_throws" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +RParen ")" +LBrace "{" +Return "return" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "oops" +Quote "\"" +RBrace "}" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap new file mode 100644 index 0000000000..03836d3b8a --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap @@ -0,0 +1,467 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + CLASS_DEF + KW_CLASS "class" + WORD "StoredPureHandler" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_bad_stored_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "StoredPureHandler" + WORD "StoredPureHandler" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "StoredPureHandler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_good_stored_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "StoredPureHandler" + WORD "StoredPureHandler" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "StoredPureHandler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR "{ null }" + L_BRACE "{" + WORD "null" + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "StoredThrowingHandler" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_stored_throwing_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "StoredThrowingHandler" + WORD "StoredThrowingHandler" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "StoredThrowingHandler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_stored_local_error" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "oops" + QUOTE """ + WORD "oops" + QUOTE """ + R_BRACE "}" + WORD "null" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_stored_local_ok" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR "{ null }" + L_BRACE "{" + WORD "null" + R_BRACE "}" + WORD "null" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_stored_local_with_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "oops" + QUOTE """ + WORD "oops" + QUOTE """ + R_BRACE "}" + WORD "null" + R_BRACE "}" + TYPE_ALIAS_DEF + WORD "type" + WORD "StoredPureCb" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_stored_alias_error" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "cb" + COLON ":" + TYPE_EXPR "StoredPureCb" + WORD "StoredPureCb" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" + WORD "null" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_stored_alias_ok" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "cb" + COLON ":" + TYPE_EXPR "StoredPureCb" + WORD "StoredPureCb" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + WORD "null" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_stored_closure_bad" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + RETURN_STMT + KW_RETURN "return" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "oops" + QUOTE """ + WORD "oops" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_stored_closure_good" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + RETURN_STMT + KW_RETURN "return" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_stored_closure_with_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + RETURN_STMT + KW_RETURN "return" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "oops" + QUOTE """ + WORD "oops" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index b2819f0858..ceb16a5b32 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -225,3 +225,43 @@ function user.test_use_pure() -> int [expr] { function user.test_use_thrower() -> int [expr] { { let f = make_thrower() } f() } +class user.StoredPureHandler { + run: () -> null +} +class user.StoredThrowingHandler { + run: () -> null +} +type user.StoredPureCb = () -> int +function user.make_bad_stored_handler() -> user.StoredPureHandler [expr] { + { } user.StoredPureHandler { run: () -> null { { throw "error" } } } +} +function user.make_good_stored_handler() -> user.StoredPureHandler [expr] { + { } user.StoredPureHandler { run: () -> null { { } null } } +} +function user.make_stored_closure_bad() -> () -> int [expr] { + { return () -> int { { throw "oops" } } } +} +function user.make_stored_closure_good() -> () -> int [expr] { + { return () -> int { { } 42 } } +} +function user.make_stored_closure_with_throws() -> () -> int [expr] { + { return () -> int { { throw "oops" } } } +} +function user.make_stored_throwing_handler() -> user.StoredThrowingHandler [expr] { + { } user.StoredThrowingHandler { run: () -> null { { throw "error" } } } +} +function user.test_stored_alias_error() -> null [expr] { + { let cb: user.StoredPureCb = () -> int { { throw "error" } } } null +} +function user.test_stored_alias_ok() -> null [expr] { + { let cb: user.StoredPureCb = () -> int { { } 42 } } null +} +function user.test_stored_local_error() -> null [expr] { + { let f: () -> null = () -> null { { throw "oops" } } } null +} +function user.test_stored_local_ok() -> null [expr] { + { let f: () -> null = () -> null { { } null } } null +} +function user.test_stored_local_with_throws() -> null [expr] { + { let f: () -> null = () -> null { { throw "oops" } } } null +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index a5969d347e..e138d7e4f8 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -2109,3 +2109,293 @@ fn user.test_use_thrower() -> int { return; } } + +fn user.make_bad_stored_handler() -> StoredPureHandler { + // Locals: + let _0: StoredPureHandler // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = StoredPureHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + +fn user.make_good_stored_handler() -> StoredPureHandler { + // Locals: + let _0: StoredPureHandler // _0 // return + let _1: () -> null + + bb0: { + _1 = make_closure lambda[0](); + _0 = StoredPureHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.make_stored_closure_bad() -> () -> int { + // Locals: + let _0: () -> int // _0 // return + + bb0: { + _0 = make_closure lambda[0](); + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "oops"; + } +} + +fn user.make_stored_closure_good() -> () -> int { + // Locals: + let _0: () -> int // _0 // return + + bb0: { + _0 = make_closure lambda[0](); + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.make_stored_closure_with_throws() -> () -> int { + // Locals: + let _0: () -> int // _0 // return + + bb0: { + _0 = make_closure lambda[0](); + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "oops"; + } +} + +fn user.make_stored_throwing_handler() -> StoredThrowingHandler { + // Locals: + let _0: StoredThrowingHandler // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = StoredThrowingHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + +fn user.test_stored_alias_error() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + +fn user.test_stored_alias_ok() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_stored_local_error() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "oops"; + } +} + +fn user.test_stored_local_ok() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_stored_local_with_throws() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "oops"; + } +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index aa9b79d7f6..917c8f768c 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -708,3 +708,135 @@ function user.test_use_thrower() -> int throws string { f() : int } } +class user.StoredPureHandler { + run: () -> null +} +function user.make_bad_stored_handler() -> user.StoredPureHandler throws never { + { : user.StoredPureHandler + StoredPureHandler { run: () -> null { ... } } : user.StoredPureHandler + } + !! 378..407: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type +} +lambda user.make_bad_stored_handler { +} +function user.make_good_stored_handler() -> user.StoredPureHandler throws never { + { : user.StoredPureHandler + StoredPureHandler { run: () -> null { ... } } : user.StoredPureHandler + } +} +lambda user.make_good_stored_handler { +} +class user.StoredThrowingHandler { + run: () -> null throws string +} +function user.make_stored_throwing_handler() -> user.StoredThrowingHandler throws never { + { : user.StoredThrowingHandler + StoredThrowingHandler { run: () -> null { ... } } : user.StoredThrowingHandler + } +} +lambda user.make_stored_throwing_handler { +} +function user.test_stored_local_error() -> null throws never { + { : null + let f = : () -> never throws string + () -> null { ... } : () -> never throws string + { + throw "oops" + } + null : null + } + !! 1059..1087: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type +} +lambda user.test_stored_local_error { +} +function user.test_stored_local_ok() -> null throws never { + { : null + let f = : () -> null + () -> null { ... } : () -> null + { + null + } + null : null + } +} +lambda user.test_stored_local_ok { +} +function user.test_stored_local_with_throws() -> null throws never { + { : null + let f = : () -> never throws string + () -> null { ... } : () -> never throws string + { + throw "oops" + } + null : null + } +} +lambda user.test_stored_local_with_throws { +} +type user.StoredPureCb = () -> int +function user.test_stored_alias_error() -> null throws never { + { : null + let cb = : () -> never throws string + () -> int { ... } : () -> never throws string + { + throw "error" + } + null : null + } + !! 1651..1679: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type +} +lambda user.test_stored_alias_error { +} +function user.test_stored_alias_ok() -> null throws never { + { : null + let cb = : () -> 42 + () -> int { ... } : () -> 42 + { + 42 + } + null : null + } +} +lambda user.test_stored_alias_ok { +} +function user.make_stored_closure_bad() -> () -> int throws never { + { : never + return : () -> never throws string + () -> int { ... } : () -> never throws string + { + throw "oops" + } + } + !! 1980..2007: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type +} +lambda user.make_stored_closure_bad { +} +function user.make_stored_closure_good() -> () -> int throws never { + { : never + return : () -> 42 + () -> int { ... } : () -> 42 + { + 42 + } + } +} +lambda user.make_stored_closure_good { +} +function user.make_stored_closure_with_throws() -> () -> int throws string throws never { + { : never + return : () -> never throws string + () -> int { ... } : () -> never throws string + { + throw "oops" + } + } +} +lambda user.make_stored_closure_with_throws { +} +class user.StoredPureHandler$stream { + run: unknown +} +class user.StoredThrowingHandler$stream { + run: unknown +} +type user.StoredPureCb$stream = unknown diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index ca1fc4928c..64b9a2d522 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -113,3 +113,43 @@ source: crates/baml_tests/src/generated_tests.rs │ │ Note: Error code: E0001 ────╯ + + [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + ╭─[ stored_callback_enforcement.baml:12:27 ] + │ + 12 │ StoredPureHandler { run: () -> null { throw "error" } } + │ ──────────────┬────────────── + │ ╰──────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + ╭─[ stored_callback_enforcement.baml:35:22 ] + │ + 35 │ let f: () -> null = () -> null { throw "oops" } + │ ──────────────┬───────────── + │ ╰─────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + ╭─[ stored_callback_enforcement.baml:57:25 ] + │ + 57 │ let cb: StoredPureCb = () -> int { throw "error" } + │ ──────────────┬───────────── + │ ╰─────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + ╭─[ stored_callback_enforcement.baml:71:9 ] + │ + 71 │ return () -> int { throw "oops" } + │ ─────────────┬───────────── + │ ╰─────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + │ + │ Note: Error code: E0001 +────╯ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index caba8aa39b..8fb01140af 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -114,6 +114,22 @@ function user.helper_with_body_throw() -> int { throw } +function user.make_bad_stored_handler() -> StoredPureHandler { + alloc_instance StoredPureHandler + copy 0 + make_closure ., 0 + store_field .run + return +} + +function user.make_good_stored_handler() -> StoredPureHandler { + alloc_instance StoredPureHandler + copy 0 + make_closure ., 0 + store_field .run + return +} + function user.make_pure() -> () -> int { load_const null return @@ -127,6 +143,29 @@ function user.make_pure_handler() -> PureHandler { return } +function user.make_stored_closure_bad() -> () -> int { + make_closure ., 0 + return +} + +function user.make_stored_closure_good() -> () -> int { + make_closure ., 0 + return +} + +function user.make_stored_closure_with_throws() -> () -> int { + make_closure ., 0 + return +} + +function user.make_stored_throwing_handler() -> StoredThrowingHandler { + alloc_instance StoredThrowingHandler + copy 0 + make_closure ., 0 + store_field .run + return +} + function user.make_thrower() -> () -> int { load_const null return @@ -441,6 +480,31 @@ function user.test_run_throwing() -> int { return } +function user.test_stored_alias_error() -> null { + load_const null + return +} + +function user.test_stored_alias_ok() -> null { + load_const null + return +} + +function user.test_stored_local_error() -> null { + load_const null + return +} + +function user.test_stored_local_ok() -> null { + load_const null + return +} + +function user.test_stored_local_with_throws() -> null { + load_const null + return +} + function user.test_throwing_int() -> string { make_closure ., 0 call_indirect diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__stored_callback_enforcement.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__stored_callback_enforcement.snap new file mode 100644 index 0000000000..fbd49f4145 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__stored_callback_enforcement.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Unexpected additional element at stored_callback_enforcement.baml:23:18 From 4282b748b9d632722e9190a7ac88499a74abc39b Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 06:46:44 -0500 Subject: [PATCH 05/26] baml_language: add stdlib migration and throws regression coverage --- .../baml_std/testing/registry.baml | 4 +- .../function_type_throws/compose_hof.baml | 41 + .../enum_class_throws.baml | 203 +++ .../baml_tests____testing_std____04_tir.snap | 8 +- ...on_type_throws__01_lexer__compose_hof.snap | 263 +++ ...e_throws__01_lexer__enum_class_throws.snap | 1130 +++++++++++++ ...n_type_throws__02_parser__compose_hof.snap | 358 ++++ ..._throws__02_parser__enum_class_throws.snap | 1440 +++++++++++++++++ ...l_tests__function_type_throws__03_hir.snap | 117 ++ ...tests__function_type_throws__04_5_mir.snap | 1057 ++++++++++++ ...l_tests__function_type_throws__04_tir.snap | 381 +++++ ..._function_type_throws__05_diagnostics.snap | 122 ++ ...sts__function_type_throws__06_codegen.snap | 460 ++++++ ...ype_throws__10_formatter__compose_hof.snap | 5 + ...rows__10_formatter__enum_class_throws.snap | 5 + 15 files changed, 5588 insertions(+), 6 deletions(-) create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/enum_class_throws.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__enum_class_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__compose_hof.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__enum_class_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__compose_hof.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__enum_class_throws.snap diff --git a/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml b/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml index ab692ccd0f..80b9de728d 100644 --- a/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml +++ b/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml @@ -4,13 +4,13 @@ class TestRegistration { name string - body () -> null + body () -> null throws unknown runner TestRunner? } class TestSetRegistration { name string - collector (TestCollector) -> null + collector (TestCollector) -> null throws unknown runner TestSetRunner? } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml b/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml new file mode 100644 index 0000000000..587f615da8 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml @@ -0,0 +1,41 @@ +// === Compose HOF: nested higher-order functions with generics === + +// Compose two functions, propagating throws from both +function compose( + f: (A) -> B, + g: (B) -> C +) -> (A) -> C { + (a: A) -> C { g(f(a)) } +} + +// Pure composition +function test_compose_pure() -> (int) -> string { + compose( + (x: int) -> int { x * 2 }, + (y: int) -> string { "result" } + ) +} + +// First function throws +function test_compose_first_throws() -> (int) -> string { + compose( + (x: int) -> int { throw "f failed" }, + (y: int) -> string { "result" } + ) +} + +// Second function throws +function test_compose_second_throws() -> (int) -> string { + compose( + (x: int) -> int { x * 2 }, + (y: int) -> string { throw "g failed" } + ) +} + +// Both functions throw different types +function test_compose_both_throw() -> (int) -> string { + compose( + (x: int) -> int { throw "string error" }, + (y: int) -> string { throw 42 } + ) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/enum_class_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/enum_class_throws.baml new file mode 100644 index 0000000000..7055b43533 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/enum_class_throws.baml @@ -0,0 +1,203 @@ +// === Enum, class, and union throws coverage === + +// --- Enum definitions --- + +enum ErrorKind { + NotFound + Unauthorized + ValidationFailed + RateLimited +} + +// --- Class definitions --- + +class ApiError { + code: int + message: string +} + +class ValidationError { + field: string + reason: string +} + +// === Throwing enums === + +// Throw a specific enum variant +function throw_enum_variant(x: int) -> string throws ErrorKind { + if (x == 0) { throw ErrorKind.NotFound } + if (x == 1) { throw ErrorKind.Unauthorized } + "ok" +} + +// Lambda that throws an enum variant +function test_lambda_throws_enum() -> int { + let f = () -> int { throw ErrorKind.ValidationFailed } + f() +} + +// === Throwing classes === + +// Throw a class instance +function throw_class_instance(x: int) -> string throws ApiError { + if (x < 0) { + throw ApiError { code: 500, message: "internal error" } + } + "ok" +} + +// Lambda that throws a class +function test_lambda_throws_class() -> int { + let f = () -> int { + throw ValidationError { field: "email", reason: "invalid format" } + } + f() +} + +// === Union throws types === + +// Union of enum variants (single enum type covers all) +function throw_various_errors(x: int) -> string throws ErrorKind { + match (x) { + 0 => throw ErrorKind.NotFound, + 1 => throw ErrorKind.Unauthorized, + 2 => throw ErrorKind.ValidationFailed, + _ => "ok" + } +} + +// Union of different classes +function throw_mixed_classes(x: int) -> string throws ApiError | ValidationError { + if (x == 0) { + throw ApiError { code: 404, message: "not found" } + } + if (x == 1) { + throw ValidationError { field: "id", reason: "required" } + } + "ok" +} + +// Union of enum and class +function throw_enum_or_class(x: int) -> string throws ErrorKind | ApiError { + if (x == 0) { throw ErrorKind.RateLimited } + if (x == 1) { throw ApiError { code: 503, message: "unavailable" } } + "ok" +} + +// Union of primitive, enum, and class +function throw_any_error(x: int) -> string throws string | ErrorKind | ApiError { + match (x) { + 0 => throw "simple string error", + 1 => throw ErrorKind.NotFound, + 2 => throw ApiError { code: 400, message: "bad request" }, + _ => "ok" + } +} + +// === HOF with enum/class throws === + +// HOF with callback that throws enum +function apply_may_throw_enum(f: () -> int throws ErrorKind) -> int { + f() +} + +function test_apply_enum_thrower() -> int { + apply_may_throw_enum(() -> int { throw ErrorKind.Unauthorized }) +} + +// HOF with callback that throws class +function apply_may_throw_class(f: () -> int throws ApiError) -> int { + f() +} + +function test_apply_class_thrower() -> int { + apply_may_throw_class(() -> int { + throw ApiError { code: 401, message: "unauthorized" } + }) +} + +// HOF with implicit rethrows - enum callback +function apply_generic_enum(f: () -> int) -> int { f() } + +function test_rethrows_enum() -> int { + apply_generic_enum(() -> int { throw ErrorKind.NotFound }) +} + +// HOF with implicit rethrows - class callback +function apply_generic_class(f: () -> int) -> int { f() } + +function test_rethrows_class() -> int { + apply_generic_class(() -> int { + throw ValidationError { field: "name", reason: "too short" } + }) +} + +// === Catch patterns with enums/classes === + +function catch_enum_variants(x: int) -> string { + throw_enum_variant(x) catch (e) { + ErrorKind.NotFound => "not found", + ErrorKind.Unauthorized => "unauthorized", + _ => "other error" + } +} + +function catch_class_error(x: int) -> string { + throw_class_instance(x) catch (e) { + _: ApiError => "api error: " + e.message + } +} + +function catch_mixed_errors(x: int) -> string { + throw_enum_or_class(x) catch (e) { + ErrorKind.RateLimited => "rate limited", + _: ApiError => "api error", + _ => "unknown" + } +} + +// === Stored callbacks with enum/class throws === + +class EnumThrowingHandler { + run: () -> null throws ErrorKind +} + +class ClassThrowingHandler { + run: () -> null throws ApiError +} + +class MixedThrowingHandler { + run: () -> null throws ErrorKind | ApiError +} + +function make_enum_handler() -> EnumThrowingHandler { + EnumThrowingHandler { run: () -> null { throw ErrorKind.NotFound } } +} + +function make_class_handler() -> ClassThrowingHandler { + ClassThrowingHandler { + run: () -> null { throw ApiError { code: 500, message: "fail" } } + } +} + +// === Type aliases with enum/class throws === + +type EnumThrower = () -> int throws ErrorKind +type ClassThrower = () -> int throws ApiError +type MixedThrower = () -> int throws ErrorKind | ApiError | string + +function use_enum_thrower(f: EnumThrower) -> int { f() } +function use_class_thrower(f: ClassThrower) -> int { f() } +function use_mixed_thrower(f: MixedThrower) -> int { f() } + +function test_type_alias_enum() -> int { + use_enum_thrower(() -> int { throw ErrorKind.ValidationFailed }) +} + +function test_type_alias_class() -> int { + use_class_thrower(() -> int { throw ApiError { code: 422, message: "invalid" } }) +} + +function test_type_alias_mixed() -> int { + use_mixed_thrower(() -> int { throw "string error" }) +} diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap index 06a79c288e..cfe42b2e21 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap @@ -8,12 +8,12 @@ source: crates/baml_tests/src/generated_tests.rs --- /testing/registry.baml --- class testing.TestRegistration { name: string - body: () -> null + body: () -> null throws unknown runner: testing.TestRunner? } class testing.TestSetRegistration { name: string - collector: (testing.TestCollector) -> null + collector: (testing.TestCollector) -> null throws unknown runner: testing.TestSetRunner? } class testing.TestCollector { @@ -160,7 +160,7 @@ function testing.TestRegistry.serialize(self: testing.TestRegistry) -> testing.S return items : testing.SerializedTestDef[] } } -function testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> testing.TestReport throws string | unknown { +function testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> testing.TestReport throws string | unknown | unknown { { : never for t in self.collector.tests { : void @@ -180,7 +180,7 @@ function testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) throw "Test not found: " + name : string } } -function testing.TestRegistry.expand_set(self: testing.TestRegistry, name: string) -> testing.SerializedTestDef[] throws string | unknown { +function testing.TestRegistry.expand_set(self: testing.TestRegistry, name: string) -> testing.SerializedTestDef[] throws string | unknown | unknown { { : never for ts in self.collector.testsets { : void diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap new file mode 100644 index 0000000000..f178332743 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap @@ -0,0 +1,263 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Compose" +Word "HOF" +Colon ":" +Word "nested" +Word "higher-order" +Word "functions" +Word "with" +Word "generics" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Compose" +Word "two" +Word "functions" +Comma "," +Word "propagating" +Throws "throws" +Word "from" +Word "both" +Function "function" +Word "compose" +Less "<" +Word "A" +Comma "," +Word "B" +Comma "," +Word "C" +Greater ">" +LParen "(" +Word "f" +Colon ":" +LParen "(" +Word "A" +RParen ")" +Arrow "->" +Word "B" +Comma "," +Word "g" +Colon ":" +LParen "(" +Word "B" +RParen ")" +Arrow "->" +Word "C" +RParen ")" +Arrow "->" +LParen "(" +Word "A" +RParen ")" +Arrow "->" +Word "C" +LBrace "{" +LParen "(" +Word "a" +Colon ":" +Word "A" +RParen ")" +Arrow "->" +Word "C" +LBrace "{" +Word "g" +LParen "(" +Word "f" +LParen "(" +Word "a" +RParen ")" +RParen ")" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "Pure" +Word "composition" +Function "function" +Word "test_compose_pure" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "compose" +LParen "(" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "x" +Star "*" +IntegerLiteral "2" +RBrace "}" +Comma "," +LParen "(" +Word "y" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Quote "\"" +Word "result" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "First" +Function "function" +Throws "throws" +Function "function" +Word "test_compose_first_throws" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "compose" +LParen "(" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "f" +Word "failed" +Quote "\"" +RBrace "}" +Comma "," +LParen "(" +Word "y" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Quote "\"" +Word "result" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Second" +Function "function" +Throws "throws" +Function "function" +Word "test_compose_second_throws" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "compose" +LParen "(" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "x" +Star "*" +IntegerLiteral "2" +RBrace "}" +Comma "," +LParen "(" +Word "y" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Throw "throw" +Quote "\"" +Word "g" +Word "failed" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Both" +Word "functions" +Throw "throw" +Word "different" +Word "types" +Function "function" +Word "test_compose_both_throw" +LParen "(" +RParen ")" +Arrow "->" +LParen "(" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "compose" +LParen "(" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "string" +Word "error" +Quote "\"" +RBrace "}" +Comma "," +LParen "(" +Word "y" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__enum_class_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__enum_class_throws.snap new file mode 100644 index 0000000000..a7588fe63b --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__enum_class_throws.snap @@ -0,0 +1,1130 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Enum" +Comma "," +Class "class" +Comma "," +Word "and" +Word "union" +Throws "throws" +Word "coverage" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +MinusMinus "--" +Minus "-" +Word "Enum" +Word "definitions" +MinusMinus "--" +Minus "-" +Enum "enum" +Word "ErrorKind" +LBrace "{" +Word "NotFound" +Word "Unauthorized" +Word "ValidationFailed" +Word "RateLimited" +RBrace "}" +Slash "/" +Slash "/" +MinusMinus "--" +Minus "-" +Word "Class" +Word "definitions" +MinusMinus "--" +Minus "-" +Class "class" +Word "ApiError" +LBrace "{" +Word "code" +Colon ":" +Word "int" +Word "message" +Colon ":" +Word "string" +RBrace "}" +Class "class" +Word "ValidationError" +LBrace "{" +Word "field" +Colon ":" +Word "string" +Word "reason" +Colon ":" +Word "string" +RBrace "}" +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Throwing" +Word "enums" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Throw" +Word "a" +Word "specific" +Enum "enum" +Word "variant" +Function "function" +Word "throw_enum_variant" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "ErrorKind" +LBrace "{" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "NotFound" +RBrace "}" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "1" +RParen ")" +LBrace "{" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "Unauthorized" +RBrace "}" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +Slash "/" +Slash "/" +Word "Lambda" +Word "that" +Throws "throws" +Word "an" +Enum "enum" +Word "variant" +Function "function" +Word "test_lambda_throws_enum" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "ValidationFailed" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Throwing" +Word "classes" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Throw" +Word "a" +Class "class" +Word "instance" +Function "function" +Word "throw_class_instance" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "ApiError" +LBrace "{" +If "if" +LParen "(" +Word "x" +Less "<" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Word "ApiError" +LBrace "{" +Word "code" +Colon ":" +IntegerLiteral "500" +Comma "," +Word "message" +Colon ":" +Quote "\"" +Word "internal" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +Slash "/" +Slash "/" +Word "Lambda" +Word "that" +Throws "throws" +Word "a" +Class "class" +Function "function" +Word "test_lambda_throws_class" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Word "ValidationError" +LBrace "{" +Word "field" +Colon ":" +Quote "\"" +Word "email" +Quote "\"" +Comma "," +Word "reason" +Colon ":" +Quote "\"" +Word "invalid" +Word "format" +Quote "\"" +RBrace "}" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Union" +Throws "throws" +Word "types" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Union" +Word "of" +Enum "enum" +Word "variants" +LParen "(" +Word "single" +Enum "enum" +Word "type" +Word "covers" +Word "all" +RParen ")" +Function "function" +Word "throw_various_errors" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "ErrorKind" +LBrace "{" +Match "match" +LParen "(" +Word "x" +RParen ")" +LBrace "{" +IntegerLiteral "0" +FatArrow "=>" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "NotFound" +Comma "," +IntegerLiteral "1" +FatArrow "=>" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "Unauthorized" +Comma "," +IntegerLiteral "2" +FatArrow "=>" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "ValidationFailed" +Comma "," +Word "_" +FatArrow "=>" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "Union" +Word "of" +Word "different" +Word "classes" +Function "function" +Word "throw_mixed_classes" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "ApiError" +Pipe "|" +Word "ValidationError" +LBrace "{" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Word "ApiError" +LBrace "{" +Word "code" +Colon ":" +IntegerLiteral "404" +Comma "," +Word "message" +Colon ":" +Quote "\"" +Word "not" +Word "found" +Quote "\"" +RBrace "}" +RBrace "}" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "1" +RParen ")" +LBrace "{" +Throw "throw" +Word "ValidationError" +LBrace "{" +Word "field" +Colon ":" +Quote "\"" +Word "id" +Quote "\"" +Comma "," +Word "reason" +Colon ":" +Quote "\"" +Word "required" +Quote "\"" +RBrace "}" +RBrace "}" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +Slash "/" +Slash "/" +Word "Union" +Word "of" +Enum "enum" +Word "and" +Class "class" +Function "function" +Word "throw_enum_or_class" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "ErrorKind" +Pipe "|" +Word "ApiError" +LBrace "{" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "0" +RParen ")" +LBrace "{" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "RateLimited" +RBrace "}" +If "if" +LParen "(" +Word "x" +EqualsEquals "==" +IntegerLiteral "1" +RParen ")" +LBrace "{" +Throw "throw" +Word "ApiError" +LBrace "{" +Word "code" +Colon ":" +IntegerLiteral "503" +Comma "," +Word "message" +Colon ":" +Quote "\"" +Word "unavailable" +Quote "\"" +RBrace "}" +RBrace "}" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +Slash "/" +Slash "/" +Word "Union" +Word "of" +Word "primitive" +Comma "," +Enum "enum" +Comma "," +Word "and" +Class "class" +Function "function" +Word "throw_any_error" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +Throws "throws" +Word "string" +Pipe "|" +Word "ErrorKind" +Pipe "|" +Word "ApiError" +LBrace "{" +Match "match" +LParen "(" +Word "x" +RParen ")" +LBrace "{" +IntegerLiteral "0" +FatArrow "=>" +Throw "throw" +Quote "\"" +Word "simple" +Word "string" +Word "error" +Quote "\"" +Comma "," +IntegerLiteral "1" +FatArrow "=>" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "NotFound" +Comma "," +IntegerLiteral "2" +FatArrow "=>" +Throw "throw" +Word "ApiError" +LBrace "{" +Word "code" +Colon ":" +IntegerLiteral "400" +Comma "," +Word "message" +Colon ":" +Quote "\"" +Word "bad" +Word "request" +Quote "\"" +RBrace "}" +Comma "," +Word "_" +FatArrow "=>" +Quote "\"" +Word "ok" +Quote "\"" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "HOF" +Word "with" +Enum "enum" +Slash "/" +Class "class" +Throws "throws" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "HOF" +Word "with" +Word "callback" +Word "that" +Throws "throws" +Enum "enum" +Function "function" +Word "apply_may_throw_enum" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "ErrorKind" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_apply_enum_thrower" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_may_throw_enum" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "Unauthorized" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "HOF" +Word "with" +Word "callback" +Word "that" +Throws "throws" +Class "class" +Function "function" +Word "apply_may_throw_class" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "ApiError" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_apply_class_thrower" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_may_throw_class" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Word "ApiError" +LBrace "{" +Word "code" +Colon ":" +IntegerLiteral "401" +Comma "," +Word "message" +Colon ":" +Quote "\"" +Word "unauthorized" +Quote "\"" +RBrace "}" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "HOF" +Word "with" +Word "implicit" +Word "rethrows" +Minus "-" +Enum "enum" +Word "callback" +Function "function" +Word "apply_generic_enum" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_rethrows_enum" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_generic_enum" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "NotFound" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "HOF" +Word "with" +Word "implicit" +Word "rethrows" +Minus "-" +Class "class" +Word "callback" +Function "function" +Word "apply_generic_class" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_rethrows_class" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_generic_class" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Word "ValidationError" +LBrace "{" +Word "field" +Colon ":" +Quote "\"" +Word "name" +Quote "\"" +Comma "," +Word "reason" +Colon ":" +Quote "\"" +Word "too" +Word "short" +Quote "\"" +RBrace "}" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Catch" +Word "patterns" +Word "with" +Word "enums" +Slash "/" +Word "classes" +EqualsEquals "==" +Equals "=" +Function "function" +Word "catch_enum_variants" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "throw_enum_variant" +LParen "(" +Word "x" +RParen ")" +Catch "catch" +LParen "(" +Word "e" +RParen ")" +LBrace "{" +Word "ErrorKind" +Dot "." +Word "NotFound" +FatArrow "=>" +Quote "\"" +Word "not" +Word "found" +Quote "\"" +Comma "," +Word "ErrorKind" +Dot "." +Word "Unauthorized" +FatArrow "=>" +Quote "\"" +Word "unauthorized" +Quote "\"" +Comma "," +Word "_" +FatArrow "=>" +Quote "\"" +Word "other" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +Function "function" +Word "catch_class_error" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "throw_class_instance" +LParen "(" +Word "x" +RParen ")" +Catch "catch" +LParen "(" +Word "e" +RParen ")" +LBrace "{" +Word "_" +Colon ":" +Word "ApiError" +FatArrow "=>" +Quote "\"" +Word "api" +Word "error" +Colon ":" +Quote "\"" +Plus "+" +Word "e" +Dot "." +Word "message" +RBrace "}" +RBrace "}" +Function "function" +Word "catch_mixed_errors" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "throw_enum_or_class" +LParen "(" +Word "x" +RParen ")" +Catch "catch" +LParen "(" +Word "e" +RParen ")" +LBrace "{" +Word "ErrorKind" +Dot "." +Word "RateLimited" +FatArrow "=>" +Quote "\"" +Word "rate" +Word "limited" +Quote "\"" +Comma "," +Word "_" +Colon ":" +Word "ApiError" +FatArrow "=>" +Quote "\"" +Word "api" +Word "error" +Quote "\"" +Comma "," +Word "_" +FatArrow "=>" +Quote "\"" +Word "unknown" +Quote "\"" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Stored" +Word "callbacks" +Word "with" +Enum "enum" +Slash "/" +Class "class" +Throws "throws" +EqualsEquals "==" +Equals "=" +Class "class" +Word "EnumThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "ErrorKind" +RBrace "}" +Class "class" +Word "ClassThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "ApiError" +RBrace "}" +Class "class" +Word "MixedThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "ErrorKind" +Pipe "|" +Word "ApiError" +RBrace "}" +Function "function" +Word "make_enum_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "EnumThrowingHandler" +LBrace "{" +Word "EnumThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "NotFound" +RBrace "}" +RBrace "}" +RBrace "}" +Function "function" +Word "make_class_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "ClassThrowingHandler" +LBrace "{" +Word "ClassThrowingHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Word "ApiError" +LBrace "{" +Word "code" +Colon ":" +IntegerLiteral "500" +Comma "," +Word "message" +Colon ":" +Quote "\"" +Word "fail" +Quote "\"" +RBrace "}" +RBrace "}" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Type" +Word "aliases" +Word "with" +Enum "enum" +Slash "/" +Class "class" +Throws "throws" +EqualsEquals "==" +Equals "=" +Word "type" +Word "EnumThrower" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "ErrorKind" +Word "type" +Word "ClassThrower" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "ApiError" +Word "type" +Word "MixedThrower" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "ErrorKind" +Pipe "|" +Word "ApiError" +Pipe "|" +Word "string" +Function "function" +Word "use_enum_thrower" +LParen "(" +Word "f" +Colon ":" +Word "EnumThrower" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "use_class_thrower" +LParen "(" +Word "f" +Colon ":" +Word "ClassThrower" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "use_mixed_thrower" +LParen "(" +Word "f" +Colon ":" +Word "MixedThrower" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_type_alias_enum" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "use_enum_thrower" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Word "ErrorKind" +Dot "." +Word "ValidationFailed" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_type_alias_class" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "use_class_thrower" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Word "ApiError" +LBrace "{" +Word "code" +Colon ":" +IntegerLiteral "422" +Comma "," +Word "message" +Colon ":" +Quote "\"" +Word "invalid" +Quote "\"" +RBrace "}" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_type_alias_mixed" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "use_mixed_thrower" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "string" +Word "error" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__compose_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__compose_hof.snap new file mode 100644 index 0000000000..732d5c9073 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__compose_hof.snap @@ -0,0 +1,358 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "compose" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "A" + WORD "A" + COMMA "," + GENERIC_PARAM "B" + WORD "B" + COMMA "," + GENERIC_PARAM "C" + WORD "C" + GREATER ">" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "A" + WORD "A" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "B" + WORD "B" + COMMA "," + PARAMETER + WORD "g" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "B" + WORD "B" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "C" + WORD "C" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "A" + WORD "A" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "C" + WORD "C" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "a" + COLON ":" + TYPE_EXPR "A" + WORD "A" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "C" + WORD "C" + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "g" + CALL_ARGS + L_PAREN "(" + CALL_EXPR + WORD "f" + CALL_ARGS "(a)" + L_PAREN "(" + WORD "a" + R_PAREN ")" + R_PAREN ")" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_compose_pure" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "compose" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + BINARY_EXPR "x * 2" + WORD "x" + STAR "*" + INTEGER_LITERAL "2" + R_BRACE "}" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "y" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + STRING_LITERAL "result" + QUOTE """ + WORD "result" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_compose_first_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "compose" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "f failed" + QUOTE """ + WORD "f" + WORD "failed" + QUOTE """ + R_BRACE "}" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "y" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + STRING_LITERAL "result" + QUOTE """ + WORD "result" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_compose_second_throws" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "compose" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + BINARY_EXPR "x * 2" + WORD "x" + STAR "*" + INTEGER_LITERAL "2" + R_BRACE "}" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "y" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "g failed" + QUOTE """ + WORD "g" + WORD "failed" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_compose_both_throw" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "compose" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "string error" + QUOTE """ + WORD "string" + WORD "error" + QUOTE """ + R_BRACE "}" + COMMA "," + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "y" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__enum_class_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__enum_class_throws.snap new file mode 100644 index 0000000000..873ce7c2c1 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__enum_class_throws.snap @@ -0,0 +1,1440 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + ENUM_DEF + KW_ENUM "enum" + WORD "ErrorKind" + L_BRACE "{" + ENUM_VARIANT "NotFound" + WORD "NotFound" + ENUM_VARIANT "Unauthorized" + WORD "Unauthorized" + ENUM_VARIANT "ValidationFailed" + WORD "ValidationFailed" + ENUM_VARIANT "RateLimited" + WORD "RateLimited" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "ApiError" + L_BRACE "{" + FIELD + WORD "code" + COLON ":" + TYPE_EXPR "int" + WORD "int" + FIELD + WORD "message" + COLON ":" + TYPE_EXPR "string" + WORD "string" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "ValidationError" + L_BRACE "{" + FIELD + WORD "field" + COLON ":" + TYPE_EXPR "string" + WORD "string" + FIELD + WORD "reason" + COLON ":" + TYPE_EXPR "string" + WORD "string" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "throw_enum_variant" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ErrorKind" + WORD "ErrorKind" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 0" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.NotFound" + WORD "ErrorKind" + DOT "." + WORD "NotFound" + R_BRACE "}" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 1" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "1" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.Unauthorized" + WORD "ErrorKind" + DOT "." + WORD "Unauthorized" + R_BRACE "}" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_lambda_throws_enum" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.ValidationFailed" + WORD "ErrorKind" + DOT "." + WORD "ValidationFailed" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "throw_class_instance" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ApiError" + WORD "ApiError" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x < 0" + WORD "x" + LESS "<" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ApiError" + L_BRACE "{" + OBJECT_FIELD "code: 500" + WORD "code" + COLON ":" + INTEGER_LITERAL "500" + COMMA "," + OBJECT_FIELD + WORD "message" + COLON ":" + STRING_LITERAL "internal error" + QUOTE """ + WORD "internal" + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_lambda_throws_class" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ValidationError" + L_BRACE "{" + OBJECT_FIELD + WORD "field" + COLON ":" + STRING_LITERAL "email" + QUOTE """ + WORD "email" + QUOTE """ + COMMA "," + OBJECT_FIELD + WORD "reason" + COLON ":" + STRING_LITERAL "invalid format" + QUOTE """ + WORD "invalid" + WORD "format" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "throw_various_errors" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ErrorKind" + WORD "ErrorKind" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + MATCH_EXPR + KW_MATCH "match" + L_PAREN "(" + WORD "x" + R_PAREN ")" + L_BRACE "{" + MATCH_ARM + MATCH_PATTERN "0" + INTEGER_LITERAL "0" + FAT_ARROW "=>" + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.NotFound" + WORD "ErrorKind" + DOT "." + WORD "NotFound" + COMMA "," + MATCH_ARM + MATCH_PATTERN "1" + INTEGER_LITERAL "1" + FAT_ARROW "=>" + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.Unauthorized" + WORD "ErrorKind" + DOT "." + WORD "Unauthorized" + COMMA "," + MATCH_ARM + MATCH_PATTERN "2" + INTEGER_LITERAL "2" + FAT_ARROW "=>" + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.ValidationFailed" + WORD "ErrorKind" + DOT "." + WORD "ValidationFailed" + COMMA "," + MATCH_ARM + MATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "throw_mixed_classes" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ApiError | ValidationError" + WORD "ApiError" + PIPE "|" + WORD "ValidationError" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 0" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ApiError" + L_BRACE "{" + OBJECT_FIELD "code: 404" + WORD "code" + COLON ":" + INTEGER_LITERAL "404" + COMMA "," + OBJECT_FIELD + WORD "message" + COLON ":" + STRING_LITERAL "not found" + QUOTE """ + WORD "not" + WORD "found" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 1" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "1" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ValidationError" + L_BRACE "{" + OBJECT_FIELD + WORD "field" + COLON ":" + STRING_LITERAL "id" + QUOTE """ + WORD "id" + QUOTE """ + COMMA "," + OBJECT_FIELD + WORD "reason" + COLON ":" + STRING_LITERAL "required" + QUOTE """ + WORD "required" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "throw_enum_or_class" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ErrorKind | ApiError" + WORD "ErrorKind" + PIPE "|" + WORD "ApiError" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 0" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "0" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.RateLimited" + WORD "ErrorKind" + DOT "." + WORD "RateLimited" + R_BRACE "}" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "x == 1" + WORD "x" + EQUALS_EQUALS "==" + INTEGER_LITERAL "1" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ApiError" + L_BRACE "{" + OBJECT_FIELD "code: 503" + WORD "code" + COLON ":" + INTEGER_LITERAL "503" + COMMA "," + OBJECT_FIELD + WORD "message" + COLON ":" + STRING_LITERAL "unavailable" + QUOTE """ + WORD "unavailable" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "throw_any_error" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string | ErrorKind | ApiError" + WORD "string" + PIPE "|" + WORD "ErrorKind" + PIPE "|" + WORD "ApiError" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + MATCH_EXPR + KW_MATCH "match" + L_PAREN "(" + WORD "x" + R_PAREN ")" + L_BRACE "{" + MATCH_ARM + MATCH_PATTERN "0" + INTEGER_LITERAL "0" + FAT_ARROW "=>" + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "simple string error" + QUOTE """ + WORD "simple" + WORD "string" + WORD "error" + QUOTE """ + COMMA "," + MATCH_ARM + MATCH_PATTERN "1" + INTEGER_LITERAL "1" + FAT_ARROW "=>" + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.NotFound" + WORD "ErrorKind" + DOT "." + WORD "NotFound" + COMMA "," + MATCH_ARM + MATCH_PATTERN "2" + INTEGER_LITERAL "2" + FAT_ARROW "=>" + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ApiError" + L_BRACE "{" + OBJECT_FIELD "code: 400" + WORD "code" + COLON ":" + INTEGER_LITERAL "400" + COMMA "," + OBJECT_FIELD + WORD "message" + COLON ":" + STRING_LITERAL "bad request" + QUOTE """ + WORD "bad" + WORD "request" + QUOTE """ + R_BRACE "}" + COMMA "," + MATCH_ARM + MATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + STRING_LITERAL "ok" + QUOTE """ + WORD "ok" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_may_throw_enum" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ErrorKind" + WORD "ErrorKind" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_enum_thrower" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_may_throw_enum" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.Unauthorized" + WORD "ErrorKind" + DOT "." + WORD "Unauthorized" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_may_throw_class" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ApiError" + WORD "ApiError" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_class_thrower" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_may_throw_class" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ApiError" + L_BRACE "{" + OBJECT_FIELD "code: 401" + WORD "code" + COLON ":" + INTEGER_LITERAL "401" + COMMA "," + OBJECT_FIELD + WORD "message" + COLON ":" + STRING_LITERAL "unauthorized" + QUOTE """ + WORD "unauthorized" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_generic_enum" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_rethrows_enum" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_generic_enum" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.NotFound" + WORD "ErrorKind" + DOT "." + WORD "NotFound" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_generic_class" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_rethrows_class" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_generic_class" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ValidationError" + L_BRACE "{" + OBJECT_FIELD + WORD "field" + COLON ":" + STRING_LITERAL "name" + QUOTE """ + WORD "name" + QUOTE """ + COMMA "," + OBJECT_FIELD + WORD "reason" + COLON ":" + STRING_LITERAL "too short" + QUOTE """ + WORD "too" + WORD "short" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "catch_enum_variants" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CATCH_EXPR + CALL_EXPR + WORD "throw_enum_variant" + CALL_ARGS "(x)" + L_PAREN "(" + WORD "x" + R_PAREN ")" + CATCH_CLAUSE + KW_CATCH "catch" + L_PAREN "(" + CATCH_PATTERN "e" + WORD "e" + R_PAREN ")" + L_BRACE "{" + CATCH_ARM + CATCH_PATTERN "ErrorKind.NotFound" + WORD "ErrorKind" + DOT "." + WORD "NotFound" + FAT_ARROW "=>" + STRING_LITERAL "not found" + QUOTE """ + WORD "not" + WORD "found" + QUOTE """ + COMMA "," + CATCH_ARM + CATCH_PATTERN "ErrorKind.Unauthorized" + WORD "ErrorKind" + DOT "." + WORD "Unauthorized" + FAT_ARROW "=>" + STRING_LITERAL "unauthorized" + QUOTE """ + WORD "unauthorized" + QUOTE """ + COMMA "," + CATCH_ARM + CATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + STRING_LITERAL "other error" + QUOTE """ + WORD "other" + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "catch_class_error" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CATCH_EXPR + CALL_EXPR + WORD "throw_class_instance" + CALL_ARGS "(x)" + L_PAREN "(" + WORD "x" + R_PAREN ")" + CATCH_CLAUSE + KW_CATCH "catch" + L_PAREN "(" + CATCH_PATTERN "e" + WORD "e" + R_PAREN ")" + L_BRACE "{" + CATCH_ARM + CATCH_PATTERN + WORD "_" + COLON ":" + TYPE_EXPR "ApiError" + WORD "ApiError" + FAT_ARROW "=>" + BINARY_EXPR + STRING_LITERAL "api error: " + QUOTE """ + WORD "api" + WORD "error" + COLON ":" + QUOTE """ + PLUS "+" + PATH_EXPR "e.message" + WORD "e" + DOT "." + WORD "message" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "catch_mixed_errors" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CATCH_EXPR + CALL_EXPR + WORD "throw_enum_or_class" + CALL_ARGS "(x)" + L_PAREN "(" + WORD "x" + R_PAREN ")" + CATCH_CLAUSE + KW_CATCH "catch" + L_PAREN "(" + CATCH_PATTERN "e" + WORD "e" + R_PAREN ")" + L_BRACE "{" + CATCH_ARM + CATCH_PATTERN "ErrorKind.RateLimited" + WORD "ErrorKind" + DOT "." + WORD "RateLimited" + FAT_ARROW "=>" + STRING_LITERAL "rate limited" + QUOTE """ + WORD "rate" + WORD "limited" + QUOTE """ + COMMA "," + CATCH_ARM + CATCH_PATTERN + WORD "_" + COLON ":" + TYPE_EXPR "ApiError" + WORD "ApiError" + FAT_ARROW "=>" + STRING_LITERAL "api error" + QUOTE """ + WORD "api" + WORD "error" + QUOTE """ + COMMA "," + CATCH_ARM + CATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + STRING_LITERAL "unknown" + QUOTE """ + WORD "unknown" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "EnumThrowingHandler" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ErrorKind" + WORD "ErrorKind" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "ClassThrowingHandler" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ApiError" + WORD "ApiError" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "MixedThrowingHandler" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ErrorKind | ApiError" + WORD "ErrorKind" + PIPE "|" + WORD "ApiError" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_enum_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "EnumThrowingHandler" + WORD "EnumThrowingHandler" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "EnumThrowingHandler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.NotFound" + WORD "ErrorKind" + DOT "." + WORD "NotFound" + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_class_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "ClassThrowingHandler" + WORD "ClassThrowingHandler" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "ClassThrowingHandler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ApiError" + L_BRACE "{" + OBJECT_FIELD "code: 500" + WORD "code" + COLON ":" + INTEGER_LITERAL "500" + COMMA "," + OBJECT_FIELD + WORD "message" + COLON ":" + STRING_LITERAL "fail" + QUOTE """ + WORD "fail" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + TYPE_ALIAS_DEF + WORD "type" + WORD "EnumThrower" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ErrorKind" + WORD "ErrorKind" + TYPE_ALIAS_DEF + WORD "type" + WORD "ClassThrower" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ApiError" + WORD "ApiError" + TYPE_ALIAS_DEF + WORD "type" + WORD "MixedThrower" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "ErrorKind | ApiError | string" + WORD "ErrorKind" + PIPE "|" + WORD "ApiError" + PIPE "|" + WORD "string" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "use_enum_thrower" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR "EnumThrower" + WORD "EnumThrower" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "use_class_thrower" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR "ClassThrower" + WORD "ClassThrower" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "use_mixed_thrower" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR "MixedThrower" + WORD "MixedThrower" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_type_alias_enum" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "use_enum_thrower" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + PATH_EXPR "ErrorKind.ValidationFailed" + WORD "ErrorKind" + DOT "." + WORD "ValidationFailed" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_type_alias_class" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "use_class_thrower" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + OBJECT_LITERAL + WORD "ApiError" + L_BRACE "{" + OBJECT_FIELD "code: 422" + WORD "code" + COLON ":" + INTEGER_LITERAL "422" + COMMA "," + OBJECT_FIELD + WORD "message" + COLON ":" + STRING_LITERAL "invalid" + QUOTE """ + WORD "invalid" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_type_alias_mixed" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "use_mixed_thrower" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "string error" + QUOTE """ + WORD "string" + WORD "error" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index ceb16a5b32..9f0d8a87da 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -55,6 +55,123 @@ function user.make_pure_handler() -> user.PureHandler [expr] { function user.make_throwing_handler() -> user.ThrowingHandler [expr] { { } user.ThrowingHandler { run: () -> null { { throw "error" } } } } +function user.compose(f: (user.A) -> user.B, g: (user.B) -> user.C) -> (user.A) -> user.C [expr] { + { } +} +function user.test_compose_both_throw() -> (int) -> string [expr] { + { } compose((x: int) -> int { { throw "string error" } }, (y: int) -> string { { throw 42 } }) +} +function user.test_compose_first_throws() -> (int) -> string [expr] { + { } compose((x: int) -> int { { throw "f failed" } }, (y: int) -> string { { } "result" }) +} +function user.test_compose_pure() -> (int) -> string [expr] { + { } compose((x: int) -> int { { } x Mul 2 }, (y: int) -> string { { } "result" }) +} +function user.test_compose_second_throws() -> (int) -> string [expr] { + { } compose((x: int) -> int { { } x Mul 2 }, (y: int) -> string { { throw "g failed" } }) +} +class user.ApiError { + code: int + message: string +} +class user.ClassThrowingHandler { + run: () -> null +} +class user.EnumThrowingHandler { + run: () -> null +} +class user.MixedThrowingHandler { + run: () -> null +} +class user.ValidationError { + field: string + reason: string +} +enum user.ErrorKind {NotFound, Unauthorized, ValidationFailed, RateLimited} +type user.ClassThrower = () -> int +type user.EnumThrower = () -> int +type user.MixedThrower = () -> int +function user.apply_generic_class(f: () -> int) -> int [expr] { + { } f() +} +function user.apply_generic_enum(f: () -> int) -> int [expr] { + { } f() +} +function user.apply_may_throw_class(f: () -> int) -> int [expr] { + { } f() +} +function user.apply_may_throw_enum(f: () -> int) -> int [expr] { + { } f() +} +function user.catch_class_error(x: int) -> string [expr] { + { } throw_class_instance(x) catch (e) { _: user.ApiError => "api error: " Add e.message } +} +function user.catch_enum_variants(x: int) -> string [expr] { + { } throw_enum_variant(x) catch (e) { user.ErrorKind.NotFound => "not found", user.ErrorKind.Unauthorized => "unauthorized", _ => "other error" } +} +function user.catch_mixed_errors(x: int) -> string [expr] { + { } throw_enum_or_class(x) catch (e) { user.ErrorKind.RateLimited => "rate limited", _: user.ApiError => "api error", _ => "unknown" } +} +function user.make_class_handler() -> user.ClassThrowingHandler [expr] { + { } user.ClassThrowingHandler { run: () -> null { { throw user.ApiError { code: 500, message: "fail" } } } } +} +function user.make_enum_handler() -> user.EnumThrowingHandler [expr] { + { } user.EnumThrowingHandler { run: () -> null { { throw ErrorKind.NotFound } } } +} +function user.test_apply_class_thrower() -> int [expr] { + { } apply_may_throw_class(() -> int { { throw user.ApiError { code: 401, message: "unauthorized" } } }) +} +function user.test_apply_enum_thrower() -> int [expr] { + { } apply_may_throw_enum(() -> int { { throw ErrorKind.Unauthorized } }) +} +function user.test_lambda_throws_class() -> int [expr] { + { let f = () -> int { { throw user.ValidationError { field: "email", reason: "invalid format" } } } } f() +} +function user.test_lambda_throws_enum() -> int [expr] { + { let f = () -> int { { throw ErrorKind.ValidationFailed } } } f() +} +function user.test_rethrows_class() -> int [expr] { + { } apply_generic_class(() -> int { { throw user.ValidationError { field: "name", reason: "too short" } } }) +} +function user.test_rethrows_enum() -> int [expr] { + { } apply_generic_enum(() -> int { { throw ErrorKind.NotFound } }) +} +function user.test_type_alias_class() -> int [expr] { + { } use_class_thrower(() -> int { { throw user.ApiError { code: 422, message: "invalid" } } }) +} +function user.test_type_alias_enum() -> int [expr] { + { } use_enum_thrower(() -> int { { throw ErrorKind.ValidationFailed } }) +} +function user.test_type_alias_mixed() -> int [expr] { + { } use_mixed_thrower(() -> int { { throw "string error" } }) +} +function user.throw_any_error(x: int) -> string [expr] { + { } match (x) { 0 => throw "simple string error", 1 => throw ErrorKind.NotFound, 2 => throw user.ApiError { code: 400, message: "bad request" }, _ => "ok" } +} +function user.throw_class_instance(x: int) -> string [expr] { + { if (x Lt 0) { throw user.ApiError { code: 500, message: "internal error" } } } "ok" +} +function user.throw_enum_or_class(x: int) -> string [expr] { + { if (x Eq 0) { throw ErrorKind.RateLimited }; if (x Eq 1) { throw user.ApiError { code: 503, message: "unavailable" } } } "ok" +} +function user.throw_enum_variant(x: int) -> string [expr] { + { if (x Eq 0) { throw ErrorKind.NotFound }; if (x Eq 1) { throw ErrorKind.Unauthorized } } "ok" +} +function user.throw_mixed_classes(x: int) -> string [expr] { + { if (x Eq 0) { throw user.ApiError { code: 404, message: "not found" } }; if (x Eq 1) { throw user.ValidationError { field: "id", reason: "required" } } } "ok" +} +function user.throw_various_errors(x: int) -> string [expr] { + { } match (x) { 0 => throw ErrorKind.NotFound, 1 => throw ErrorKind.Unauthorized, 2 => throw ErrorKind.ValidationFailed, _ => "ok" } +} +function user.use_class_thrower(f: user.ClassThrower) -> int [expr] { + { } f() +} +function user.use_enum_thrower(f: user.EnumThrower) -> int [expr] { + { } f() +} +function user.use_mixed_thrower(f: user.MixedThrower) -> int [expr] { + { } f() +} function user.apply_explicit(f: () -> int) -> int [expr] { { } f() } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index e138d7e4f8..bd1a35c0b1 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -420,6 +420,1063 @@ fn .() -> null { } } +fn user.compose(f: (void) -> void, g: (void) -> void) -> (void) -> void { + // Locals: + let _0: (void) -> void // _0 // return + let _1: (void) -> void // f // param + let _2: (void) -> void // g // param + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_compose_both_throw() -> (int) -> string { + // Locals: + let _0: (int) -> string // _0 // return + let _1: (int) -> void + let _2: (int) -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = make_closure lambda[1](); + _0 = call const fn user.compose(copy _1, copy _2) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(x: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // x // param + + bb0: { + throw const "string error"; + } +} + +// lambda[1] +fn .(y: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // y // param + + bb0: { + throw const 42_i64; + } +} + +fn user.test_compose_first_throws() -> (int) -> string { + // Locals: + let _0: (int) -> string // _0 // return + let _1: (int) -> void + let _2: (int) -> "result" + + bb0: { + _1 = make_closure lambda[0](); + _2 = make_closure lambda[1](); + _0 = call const fn user.compose(copy _1, copy _2) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(x: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // x // param + + bb0: { + throw const "f failed"; + } +} + +// lambda[1] +fn .(y: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // y // param + + bb0: { + _0 = const "result"; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_compose_pure() -> (int) -> string { + // Locals: + let _0: (int) -> string // _0 // return + let _1: (int) -> int + let _2: (int) -> "result" + + bb0: { + _1 = make_closure lambda[0](); + _2 = make_closure lambda[1](); + _0 = call const fn user.compose(copy _1, copy _2) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(x: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // x // param + + bb0: { + _0 = copy _1 * const 2_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[1] +fn .(y: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // y // param + + bb0: { + _0 = const "result"; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.test_compose_second_throws() -> (int) -> string { + // Locals: + let _0: (int) -> string // _0 // return + let _1: (int) -> int + let _2: (int) -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = make_closure lambda[1](); + _0 = call const fn user.compose(copy _1, copy _2) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(x: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // x // param + + bb0: { + _0 = copy _1 * const 2_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[1] +fn .(y: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // y // param + + bb0: { + throw const "g failed"; + } +} + +fn user.apply_generic_class(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_generic_enum(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_may_throw_class(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.apply_may_throw_enum(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.catch_class_error(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: unknown // e + let _3: bool + let _4: string + let _5: ApiError + + bb0: { + _0 = call const fn user.throw_class_instance(copy _1) -> [bb1, unwind: bb2]; + } + + bb1: { + goto -> bb5; + } + + bb2: { + _3 = is_type(copy _2, Class(TypeName { name: "ApiError", module_path: ["user"], display_name: "ApiError" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] })); + branch copy _3 -> [bb4, bb3]; + } + + bb3: { + throw copy _2; + } + + bb4: { + _5 = copy _2; + _4 = copy _5.1; + _0 = const "api error: " + copy _4; + goto -> bb5; + } + + bb5: { + goto -> bb6; + } + + bb6: { + return; + } +} + +fn user.catch_enum_variants(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: unknown // e + let _3: bool + let _4: bool + + bb0: { + _0 = call const fn user.throw_enum_variant(copy _1) -> [bb1, unwind: bb2]; + } + + bb1: { + goto -> bb9; + } + + bb2: { + _3 = copy _2 == const user.ErrorKind.NotFound; + branch copy _3 -> [bb8, bb3]; + } + + bb3: { + _4 = copy _2 == const user.ErrorKind.Unauthorized; + branch copy _4 -> [bb7, bb4]; + } + + bb4: { + throw_if_panic copy _2 -> bb5; + } + + bb5: { + goto -> bb6; + } + + bb6: { + _0 = const "other error"; + goto -> bb9; + } + + bb7: { + _0 = const "unauthorized"; + goto -> bb9; + } + + bb8: { + _0 = const "not found"; + goto -> bb9; + } + + bb9: { + goto -> bb10; + } + + bb10: { + return; + } +} + +fn user.catch_mixed_errors(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: unknown // e + let _3: bool + let _4: bool + + bb0: { + _0 = call const fn user.throw_enum_or_class(copy _1) -> [bb1, unwind: bb2]; + } + + bb1: { + goto -> bb9; + } + + bb2: { + _3 = copy _2 == const user.ErrorKind.RateLimited; + branch copy _3 -> [bb8, bb3]; + } + + bb3: { + _4 = is_type(copy _2, Class(TypeName { name: "ApiError", module_path: ["user"], display_name: "ApiError" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] })); + branch copy _4 -> [bb7, bb4]; + } + + bb4: { + throw_if_panic copy _2 -> bb5; + } + + bb5: { + goto -> bb6; + } + + bb6: { + _0 = const "unknown"; + goto -> bb9; + } + + bb7: { + _0 = const "api error"; + goto -> bb9; + } + + bb8: { + _0 = const "rate limited"; + goto -> bb9; + } + + bb9: { + goto -> bb10; + } + + bb10: { + return; + } +} + +fn user.make_class_handler() -> ClassThrowingHandler { + // Locals: + let _0: ClassThrowingHandler // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = ClassThrowingHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + let _1: ApiError + + bb0: { + _1 = ApiError { const 500_i64, const "fail" }; + throw copy _1; + } +} + +fn user.make_enum_handler() -> EnumThrowingHandler { + // Locals: + let _0: EnumThrowingHandler // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = EnumThrowingHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const user.ErrorKind.NotFound; + } +} + +fn user.test_apply_class_thrower() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_may_throw_class(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + let _1: ApiError + + bb0: { + _1 = ApiError { const 401_i64, const "unauthorized" }; + throw copy _1; + } +} + +fn user.test_apply_enum_thrower() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_may_throw_enum(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const user.ErrorKind.Unauthorized; + } +} + +fn user.test_lambda_throws_class() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void // f + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + let _1: ValidationError + + bb0: { + _1 = ValidationError { const "email", const "invalid format" }; + throw copy _1; + } +} + +fn user.test_lambda_throws_enum() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void // f + let _2: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const user.ErrorKind.ValidationFailed; + } +} + +fn user.test_rethrows_class() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_generic_class(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + let _1: ValidationError + + bb0: { + _1 = ValidationError { const "name", const "too short" }; + throw copy _1; + } +} + +fn user.test_rethrows_enum() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_generic_enum(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const user.ErrorKind.NotFound; + } +} + +fn user.test_type_alias_class() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.use_class_thrower(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + let _1: ApiError + + bb0: { + _1 = ApiError { const 422_i64, const "invalid" }; + throw copy _1; + } +} + +fn user.test_type_alias_enum() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.use_enum_thrower(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const user.ErrorKind.ValidationFailed; + } +} + +fn user.test_type_alias_mixed() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.use_mixed_thrower(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "string error"; + } +} + +fn user.throw_any_error(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: ApiError + + bb0: { + switch copy _1 [0: bb6, 1: bb5, 2: bb4, otherwise: bb1]; + } + + bb1: { + _0 = const "ok"; + goto -> bb2; + } + + bb2: { + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + _2 = ApiError { const 400_i64, const "bad request" }; + throw copy _2; + } + + bb5: { + throw const user.ErrorKind.NotFound; + } + + bb6: { + throw const "simple string error"; + } +} + +fn user.throw_class_instance(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: bool + let _3: ApiError + + bb0: { + _2 = copy _1 < const 0_i64; + branch copy _2 -> [bb4, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _0 = const "ok"; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + _3 = ApiError { const 500_i64, const "internal error" }; + throw copy _3; + } +} + +fn user.throw_enum_or_class(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: bool + let _3: bool + let _4: ApiError + + bb0: { + _2 = copy _1 == const 0_i64; + branch copy _2 -> [bb7, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _3 = copy _1 == const 1_i64; + branch copy _3 -> [bb6, bb3]; + } + + bb3: { + goto -> bb4; + } + + bb4: { + _0 = const "ok"; + goto -> bb5; + } + + bb5: { + return; + } + + bb6: { + _4 = ApiError { const 503_i64, const "unavailable" }; + throw copy _4; + } + + bb7: { + throw const user.ErrorKind.RateLimited; + } +} + +fn user.throw_enum_variant(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: bool + let _3: bool + + bb0: { + _2 = copy _1 == const 0_i64; + branch copy _2 -> [bb7, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _3 = copy _1 == const 1_i64; + branch copy _3 -> [bb6, bb3]; + } + + bb3: { + goto -> bb4; + } + + bb4: { + _0 = const "ok"; + goto -> bb5; + } + + bb5: { + return; + } + + bb6: { + throw const user.ErrorKind.Unauthorized; + } + + bb7: { + throw const user.ErrorKind.NotFound; + } +} + +fn user.throw_mixed_classes(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + let _2: bool + let _3: ApiError + let _4: bool + let _5: ValidationError + + bb0: { + _2 = copy _1 == const 0_i64; + branch copy _2 -> [bb7, bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _4 = copy _1 == const 1_i64; + branch copy _4 -> [bb6, bb3]; + } + + bb3: { + goto -> bb4; + } + + bb4: { + _0 = const "ok"; + goto -> bb5; + } + + bb5: { + return; + } + + bb6: { + _5 = ValidationError { const "id", const "required" }; + throw copy _5; + } + + bb7: { + _3 = ApiError { const 404_i64, const "not found" }; + throw copy _3; + } +} + +fn user.throw_various_errors(x: int) -> string { + // Locals: + let _0: string // _0 // return + let _1: int // x // param + + bb0: { + switch copy _1 [0: bb6, 1: bb5, 2: bb4, otherwise: bb1]; + } + + bb1: { + _0 = const "ok"; + goto -> bb2; + } + + bb2: { + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + throw const user.ErrorKind.ValidationFailed; + } + + bb5: { + throw const user.ErrorKind.Unauthorized; + } + + bb6: { + throw const user.ErrorKind.NotFound; + } +} + +fn user.use_class_thrower(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.use_enum_thrower(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.use_mixed_thrower(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + fn user.apply_explicit(f: () -> int) -> int { // Locals: let _0: int // _0 // return diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 917c8f768c..87e37ba59f 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -141,6 +141,387 @@ class user.MixedHandler$stream { safe: unknown risky: unknown } +function user.compose(f: (A) -> B throws __throws_f, g: (B) -> C throws __throws_g) -> (A) -> C throws __throws_f | __throws_g { + { : (A) -> C + } + !! 193..223: missing return: expected `(A) -> C` +} +function user.test_compose_pure() -> (int) -> string throws never { + { : (int) -> string | "result" + compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | "result" + (x: int) -> int { ... } : (x: int) -> int + { + x * 2 + } + (y: int) -> string { ... } : (y: int) -> "result" + { + "result" + } + } +} +lambda user.test_compose_pure { +} +lambda user.test_compose_pure { +} +function user.test_compose_first_throws() -> (int) -> string throws string { + { : (int) -> string | "result" + compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | "result" + (x: int) -> int { ... } : (x: int) -> never throws string + { + throw "f failed" + } + (y: int) -> string { ... } : (y: int) -> "result" + { + "result" + } + } +} +lambda user.test_compose_first_throws { +} +lambda user.test_compose_first_throws { +} +function user.test_compose_second_throws() -> (int) -> string throws string { + { : (int) -> string | never + compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | never + (x: int) -> int { ... } : (x: int) -> int + { + x * 2 + } + (y: int) -> string { ... } : (y: int) -> never throws string + { + throw "g failed" + } + } +} +lambda user.test_compose_second_throws { +} +lambda user.test_compose_second_throws { +} +function user.test_compose_both_throw() -> (int) -> string throws int | string { + { : (int) -> string | never + compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | never + (x: int) -> int { ... } : (x: int) -> never throws string + { + throw "string error" + } + (y: int) -> string { ... } : (y: int) -> never throws int + { + throw 42 + } + } +} +lambda user.test_compose_both_throw { +} +lambda user.test_compose_both_throw { +} +enum user.ErrorKind +class user.ApiError { + code: int + message: string +} +class user.ValidationError { + field: string + reason: string +} +function user.throw_enum_variant(x: int) -> string throws user.ErrorKind { + { : "ok" + if (x == 0 : bool) : void + { : never + throw ErrorKind.NotFound : user.ErrorKind.NotFound + } + if (x == 1 : bool) : void + { : never + throw ErrorKind.Unauthorized : user.ErrorKind.Unauthorized + } + "ok" : "ok" + } + !! 411..421: throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized + ?? 411..421: extraneous throws declaration: user.ErrorKind +} +function user.test_lambda_throws_enum() -> int throws user.ErrorKind.ValidationFailed { + { : never + let f = : () -> never throws user.ErrorKind.ValidationFailed + () -> int { ... } : () -> never throws user.ErrorKind.ValidationFailed + { + throw ErrorKind.ValidationFailed + } + f() : never + } +} +lambda user.test_lambda_throws_enum { +} +function user.throw_class_instance(x: int) -> string throws user.ApiError { + { : "ok" + if (x < 0 : bool) : void + { : never + throw ApiError { code: 500, message: "internal error" } : user.ApiError + } + "ok" : "ok" + } +} +function user.test_lambda_throws_class() -> int throws user.ValidationError { + { : never + let f = : () -> never throws user.ValidationError + () -> int { ... } : () -> never throws user.ValidationError + { + throw ValidationError { field: "email", reason: "invalid format" } + } + f() : never + } +} +lambda user.test_lambda_throws_class { +} +function user.throw_various_errors(x: int) -> string throws user.ErrorKind { + { : "ok" + match (x : int) : "ok" + 0 => + throw ErrorKind.NotFound : never + 1 => + throw ErrorKind.Unauthorized : never + 2 => + throw ErrorKind.ValidationFailed : never + _ => + "ok" : "ok" + } + !! 1204..1214: throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized, user.ErrorKind.ValidationFailed + ?? 1204..1214: extraneous throws declaration: user.ErrorKind +} +function user.throw_mixed_classes(x: int) -> string throws user.ApiError | user.ValidationError { + { : "ok" + if (x == 0 : bool) : void + { : never + throw ApiError { code: 404, message: "not found" } : user.ApiError + } + if (x == 1 : bool) : void + { : never + throw ValidationError { field: "id", reason: "required" } : user.ValidationError + } + "ok" : "ok" + } +} +function user.throw_enum_or_class(x: int) -> string throws user.ErrorKind | user.ApiError { + { : "ok" + if (x == 0 : bool) : void + { : never + throw ErrorKind.RateLimited : user.ErrorKind.RateLimited + } + if (x == 1 : bool) : void + { : never + throw ApiError { code: 503, message: "unavailable" } : user.ApiError + } + "ok" : "ok" + } + !! 1729..1750: throws contract violation: `user.ErrorKind | user.ApiError` is missing user.ErrorKind.RateLimited + ?? 1729..1750: extraneous throws declaration: user.ErrorKind +} +function user.throw_any_error(x: int) -> string throws string | user.ErrorKind | user.ApiError { + { : "ok" + match (x : int) : "ok" + 0 => + throw "simple string error" : never + 1 => + throw ErrorKind.NotFound : never + 2 => + throw ApiError { code: 400, message: "bad request" } : never + _ => + "ok" : "ok" + } + !! 1968..1998: throws contract violation: `string | user.ErrorKind | user.ApiError` is missing user.ErrorKind.NotFound + ?? 1968..1998: extraneous throws declaration: user.ErrorKind +} +function user.apply_may_throw_enum(f: () -> int throws user.ErrorKind) -> int throws user.ErrorKind { + { : int + f() : int + } +} +function user.test_apply_enum_thrower() -> int throws user.ErrorKind { + { : int + apply_may_throw_enum(() -> int { ... }) : int + () -> int { ... } : () -> never throws user.ErrorKind.Unauthorized + { + throw ErrorKind.Unauthorized + } + } +} +lambda user.test_apply_enum_thrower { +} +function user.apply_may_throw_class(f: () -> int throws user.ApiError) -> int throws user.ApiError { + { : int + f() : int + } +} +function user.test_apply_class_thrower() -> int throws user.ApiError { + { : int + apply_may_throw_class(() -> int { ... }) : int + () -> int { ... } : () -> never throws user.ApiError + { + throw ApiError { code: 401, message: "unauthorized" } + } + } +} +lambda user.test_apply_class_thrower { +} +function user.apply_generic_enum(f: () -> int throws __throws_f) -> int throws __throws_f { + { : int + f() : int + } +} +function user.test_rethrows_enum() -> int throws user.ErrorKind.NotFound { + { : int + apply_generic_enum(() -> int { ... }) : int + () -> int { ... } : () -> never throws user.ErrorKind.NotFound + { + throw ErrorKind.NotFound + } + } +} +lambda user.test_rethrows_enum { +} +function user.apply_generic_class(f: () -> int throws __throws_f) -> int throws __throws_f { + { : int + f() : int + } +} +function user.test_rethrows_class() -> int throws user.ValidationError { + { : int + apply_generic_class(() -> int { ... }) : int + () -> int { ... } : () -> never throws user.ValidationError + { + throw ValidationError { field: "name", reason: "too short" } + } + } +} +lambda user.test_rethrows_class { +} +function user.catch_enum_variants(x: int) -> string throws never { + { : string | "not found" | "unauthorized" | "other error" + catch (throw_enum_variant(x) : string) : unknown + catch (e) + ErrorKind.NotFound => + "not found" : "not found" + ErrorKind.Unauthorized => + "unauthorized" : "unauthorized" + _ => + "other error" : "other error" + } + ?? 3367..3381: unreachable arm + ?? 3392..3405: unreachable arm +} +function user.catch_class_error(x: int) -> string throws never { + { : string + catch (throw_class_instance(x) : string) : unknown + catch (e) + _: ApiError => + "api error: " + e.message : string + } +} +function user.catch_mixed_errors(x: int) -> string throws never { + { : string | "rate limited" | "api error" | "unknown" + catch (throw_enum_or_class(x) : string) : unknown + catch (e) + ErrorKind.RateLimited => + "rate limited" : "rate limited" + _: ApiError => + "api error" : "api error" + _ => + "unknown" : "unknown" + } + ?? 3721..3730: unreachable arm +} +class user.EnumThrowingHandler { + run: () -> null throws user.ErrorKind +} +class user.ClassThrowingHandler { + run: () -> null throws user.ApiError +} +class user.MixedThrowingHandler { + run: () -> null throws user.ErrorKind | user.ApiError +} +function user.make_enum_handler() -> user.EnumThrowingHandler throws never { + { : user.EnumThrowingHandler + EnumThrowingHandler { run: () -> null { ... } } : user.EnumThrowingHandler + } +} +lambda user.make_enum_handler { +} +function user.make_class_handler() -> user.ClassThrowingHandler throws never { + { : user.ClassThrowingHandler + ClassThrowingHandler { run: () -> null { ... } } : user.ClassThrowingHandler + } +} +lambda user.make_class_handler { +} +type user.EnumThrower = () -> int throws user.ErrorKind +type user.ClassThrower = () -> int throws user.ApiError +type user.MixedThrower = () -> int throws user.ErrorKind | user.ApiError | string +function user.use_enum_thrower(f: user.EnumThrower) -> int throws never { + { : int + f() : int + } +} +function user.use_class_thrower(f: user.ClassThrower) -> int throws never { + { : int + f() : int + } +} +function user.use_mixed_thrower(f: user.MixedThrower) -> int throws never { + { : int + f() : int + } +} +function user.test_type_alias_enum() -> int throws never { + { : int + use_enum_thrower(() -> int { ... }) : int + () -> int { ... } : () -> never throws user.ErrorKind.ValidationFailed + { + throw ErrorKind.ValidationFailed + } + } +} +lambda user.test_type_alias_enum { +} +function user.test_type_alias_class() -> int throws never { + { : int + use_class_thrower(() -> int { ... }) : int + () -> int { ... } : () -> never throws user.ApiError + { + throw ApiError { code: 422, message: "invalid" } + } + } +} +lambda user.test_type_alias_class { +} +function user.test_type_alias_mixed() -> int throws never { + { : int + use_mixed_thrower(() -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "string error" + } + } +} +lambda user.test_type_alias_mixed { +} +class user.ApiError$stream { + code: null | int + message: null | string +} +class user.ValidationError$stream { + field: null | string + reason: null | string +} +class user.EnumThrowingHandler$stream { + run: unknown +} +class user.ClassThrowingHandler$stream { + run: unknown +} +class user.MixedThrowingHandler$stream { + run: unknown +} +type user.EnumThrower$stream = unknown +type user.ClassThrower$stream = unknown +type user.MixedThrower$stream = unknown function user.apply_explicit(f: () -> int throws string) -> int throws string { { : int f() : int diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index 64b9a2d522..cc5e7bd3d2 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -2,6 +2,128 @@ source: crates/baml_tests/src/generated_tests.rs --- === COMPILER2 DIAGNOSTICS === + [type] Error: missing return: expected `(A) -> C` + ╭─[ compose_hof.baml:7:14 ] + │ + 7 │ ╭─▶ ) -> (A) -> C { + ┆ ┆ + 9 │ ├─▶ } + │ │ + │ ╰─────── missing return: expected `(A) -> C` + │ + │ Note: Error code: E0001 +───╯ + + [type] Warning: extraneous throws declaration: user.ErrorKind + ╭─[ enum_class_throws.baml:27:53 ] + │ + 27 │ function throw_enum_variant(x: int) -> string throws ErrorKind { + │ ─────┬──── + │ ╰────── extraneous throws declaration: user.ErrorKind + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized + ╭─[ enum_class_throws.baml:27:53 ] + │ + 27 │ function throw_enum_variant(x: int) -> string throws ErrorKind { + │ ─────┬──── + │ ╰────── throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: extraneous throws declaration: user.ErrorKind + ╭─[ enum_class_throws.baml:60:55 ] + │ + 60 │ function throw_various_errors(x: int) -> string throws ErrorKind { + │ ─────┬──── + │ ╰────── extraneous throws declaration: user.ErrorKind + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized, user.ErrorKind.ValidationFailed + ╭─[ enum_class_throws.baml:60:55 ] + │ + 60 │ function throw_various_errors(x: int) -> string throws ErrorKind { + │ ─────┬──── + │ ╰────── throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized, user.ErrorKind.ValidationFailed + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: extraneous throws declaration: user.ErrorKind + ╭─[ enum_class_throws.baml:81:54 ] + │ + 81 │ function throw_enum_or_class(x: int) -> string throws ErrorKind | ApiError { + │ ──────────┬────────── + │ ╰──────────── extraneous throws declaration: user.ErrorKind + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `user.ErrorKind | user.ApiError` is missing user.ErrorKind.RateLimited + ╭─[ enum_class_throws.baml:81:54 ] + │ + 81 │ function throw_enum_or_class(x: int) -> string throws ErrorKind | ApiError { + │ ──────────┬────────── + │ ╰──────────── throws contract violation: `user.ErrorKind | user.ApiError` is missing user.ErrorKind.RateLimited + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: extraneous throws declaration: user.ErrorKind + ╭─[ enum_class_throws.baml:88:50 ] + │ + 88 │ function throw_any_error(x: int) -> string throws string | ErrorKind | ApiError { + │ ───────────────┬────────────── + │ ╰──────────────── extraneous throws declaration: user.ErrorKind + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `string | user.ErrorKind | user.ApiError` is missing user.ErrorKind.NotFound + ╭─[ enum_class_throws.baml:88:50 ] + │ + 88 │ function throw_any_error(x: int) -> string throws string | ErrorKind | ApiError { + │ ───────────────┬────────────── + │ ╰──────────────── throws contract violation: `string | user.ErrorKind | user.ApiError` is missing user.ErrorKind.NotFound + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: unreachable arm + ╭─[ enum_class_throws.baml:140:31 ] + │ + 140 │ ErrorKind.Unauthorized => "unauthorized", + │ ───────┬────── + │ ╰──────── unreachable arm + │ + │ Note: Error code: E0001 +─────╯ + + [type] Warning: unreachable arm + ╭─[ enum_class_throws.baml:141:10 ] + │ + 141 │ _ => "other error" + │ ──────┬────── + │ ╰──────── unreachable arm + │ + │ Note: Error code: E0001 +─────╯ + + [type] Warning: unreachable arm + ╭─[ enum_class_throws.baml:155:10 ] + │ + 155 │ _ => "unknown" + │ ────┬──── + │ ╰────── unreachable arm + │ + │ Note: Error code: E0001 +─────╯ + [validation] Error: Duplicate function `may_fail` ╭─[ catch_absorbs_throws.baml:3:10 ] │ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 8fb01140af..14b4230d46 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -37,6 +37,18 @@ function user.apply_explicit(f: () -> int) -> int { return } +function user.apply_generic_class(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.apply_generic_enum(f: () -> int) -> int { + load_var f + call_indirect + return +} + function user.apply_guarded(f: () -> int) -> int { load_var f call_indirect @@ -62,6 +74,18 @@ function user.apply_inner(f: () -> int) -> int { return } +function user.apply_may_throw_class(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.apply_may_throw_enum(f: () -> int) -> int { + load_var f + call_indirect + return +} + function user.apply_outer(g: () -> int) -> int { load_var g call user.apply_inner @@ -109,6 +133,106 @@ function user.caller() -> string { return } +function user.catch_class_error(x: int) -> string { + load_var x + call user.throw_class_instance + jump L2 + load_var e + load_const ApiError + cmp_op instanceof + pop_jump_if_false L0 + jump L1 + + L0: + load_var e + throw + + L1: + load_const "api error: " + load_var e + load_field .message + bin_op + + + L2: + return +} + +function user.catch_enum_variants(x: int) -> string { + load_var x + call user.throw_enum_variant + jump L4 + load_var e + load_const user.ErrorKind.NotFound + alloc_variant user.ErrorKind + cmp_op == + pop_jump_if_false L0 + jump L3 + + L0: + load_var e + load_const user.ErrorKind.Unauthorized + alloc_variant user.ErrorKind + cmp_op == + pop_jump_if_false L1 + jump L2 + + L1: + load_var e + throw_if_panic + load_const "other error" + jump L4 + + L2: + load_const "unauthorized" + jump L4 + + L3: + load_const "not found" + + L4: + return +} + +function user.catch_mixed_errors(x: int) -> string { + load_var x + call user.throw_enum_or_class + jump L4 + load_var e + load_const user.ErrorKind.RateLimited + alloc_variant user.ErrorKind + cmp_op == + pop_jump_if_false L0 + jump L3 + + L0: + load_var e + load_const ApiError + cmp_op instanceof + pop_jump_if_false L1 + jump L2 + + L1: + load_var e + throw_if_panic + load_const "unknown" + jump L4 + + L2: + load_const "api error" + jump L4 + + L3: + load_const "rate limited" + + L4: + return +} + +function user.compose(f: (void) -> void, g: (void) -> void) -> (void) -> void { + load_const null + return +} + function user.helper_with_body_throw() -> int { load_const "helper boom" throw @@ -122,6 +246,22 @@ function user.make_bad_stored_handler() -> StoredPureHandler { return } +function user.make_class_handler() -> ClassThrowingHandler { + alloc_instance ClassThrowingHandler + copy 0 + make_closure ., 0 + store_field .run + return +} + +function user.make_enum_handler() -> EnumThrowingHandler { + alloc_instance EnumThrowingHandler + copy 0 + make_closure ., 0 + store_field .run + return +} + function user.make_good_stored_handler() -> StoredPureHandler { alloc_instance StoredPureHandler copy 0 @@ -251,6 +391,18 @@ function user.test_apply_and_throw_pure() -> int { return } +function user.test_apply_class_thrower() -> int { + make_closure ., 0 + call user.apply_may_throw_class + return +} + +function user.test_apply_enum_thrower() -> int { + make_closure ., 0 + call user.apply_may_throw_enum + return +} + function user.test_apply_explicit_throws() -> int { make_closure ., 0 call user.apply_throwing @@ -340,6 +492,34 @@ function user.test_chained_throwing() -> int { return } +function user.test_compose_both_throw() -> (int) -> string { + make_closure ., 0 + make_closure ., 0 + call user.compose + return +} + +function user.test_compose_first_throws() -> (int) -> string { + make_closure ., 0 + make_closure ., 0 + call user.compose + return +} + +function user.test_compose_pure() -> (int) -> string { + make_closure ., 0 + make_closure ., 0 + call user.compose + return +} + +function user.test_compose_second_throws() -> (int) -> string { + make_closure ., 0 + make_closure ., 0 + call user.compose + return +} + function user.test_conditional_throw(x: int) -> int { load_var x make_closure ., 0 @@ -389,6 +569,18 @@ function user.test_guarded_throwing() -> int { return } +function user.test_lambda_throws_class() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_lambda_throws_enum() -> int { + make_closure ., 0 + call_indirect + return +} + function user.test_many_args_pure() -> string { load_const 1 load_const "hello" @@ -468,6 +660,18 @@ function user.test_pure_only_pure() -> int { return } +function user.test_rethrows_class() -> int { + make_closure ., 0 + call user.apply_generic_class + return +} + +function user.test_rethrows_enum() -> int { + make_closure ., 0 + call user.apply_generic_enum + return +} + function user.test_run_pure() -> int { make_closure ., 0 call user.run_pure @@ -544,6 +748,24 @@ function user.test_two_pure() -> int { return } +function user.test_type_alias_class() -> int { + make_closure ., 0 + call user.use_class_thrower + return +} + +function user.test_type_alias_enum() -> int { + make_closure ., 0 + call user.use_enum_thrower + return +} + +function user.test_type_alias_mixed() -> int { + make_closure ., 0 + call user.use_mixed_thrower + return +} + function user.test_use_pure() -> int { call user.make_pure call_indirect @@ -555,3 +777,241 @@ function user.test_use_thrower() -> int { call_indirect return } + +function user.throw_any_error(x: int) -> string { + load_var x + copy 0 + load_const 0 + cmp_op == + pop_jump_if_false L0 + pop 1 + jump L5 + + L0: + copy 0 + load_const 1 + cmp_op == + pop_jump_if_false L1 + pop 1 + jump L4 + + L1: + copy 0 + load_const 2 + cmp_op == + pop_jump_if_false L2 + pop 1 + jump L3 + + L2: + pop 1 + load_const "ok" + return + + L3: + alloc_instance ApiError + copy 0 + load_const 400 + store_field .code + copy 0 + load_const "bad request" + store_field .message + throw + + L4: + load_const user.ErrorKind.NotFound + alloc_variant user.ErrorKind + throw + + L5: + load_const "simple string error" + throw +} + +function user.throw_class_instance(x: int) -> string { + load_var x + load_const 0 + cmp_op < + pop_jump_if_false L0 + jump L1 + + L0: + load_const "ok" + return + + L1: + alloc_instance ApiError + copy 0 + load_const 500 + store_field .code + copy 0 + load_const "internal error" + store_field .message + throw +} + +function user.throw_enum_or_class(x: int) -> string { + load_var x + load_const 0 + cmp_op == + pop_jump_if_false L0 + jump L3 + + L0: + load_var x + load_const 1 + cmp_op == + pop_jump_if_false L1 + jump L2 + + L1: + load_const "ok" + return + + L2: + alloc_instance ApiError + copy 0 + load_const 503 + store_field .code + copy 0 + load_const "unavailable" + store_field .message + throw + + L3: + load_const user.ErrorKind.RateLimited + alloc_variant user.ErrorKind + throw +} + +function user.throw_enum_variant(x: int) -> string { + load_var x + load_const 0 + cmp_op == + pop_jump_if_false L0 + jump L3 + + L0: + load_var x + load_const 1 + cmp_op == + pop_jump_if_false L1 + jump L2 + + L1: + load_const "ok" + return + + L2: + load_const user.ErrorKind.Unauthorized + alloc_variant user.ErrorKind + throw + + L3: + load_const user.ErrorKind.NotFound + alloc_variant user.ErrorKind + throw +} + +function user.throw_mixed_classes(x: int) -> string { + load_var x + load_const 0 + cmp_op == + pop_jump_if_false L0 + jump L3 + + L0: + load_var x + load_const 1 + cmp_op == + pop_jump_if_false L1 + jump L2 + + L1: + load_const "ok" + return + + L2: + alloc_instance ValidationError + copy 0 + load_const "id" + store_field .field + copy 0 + load_const "required" + store_field .reason + throw + + L3: + alloc_instance ApiError + copy 0 + load_const 404 + store_field .code + copy 0 + load_const "not found" + store_field .message + throw +} + +function user.throw_various_errors(x: int) -> string { + load_var x + copy 0 + load_const 0 + cmp_op == + pop_jump_if_false L0 + pop 1 + jump L5 + + L0: + copy 0 + load_const 1 + cmp_op == + pop_jump_if_false L1 + pop 1 + jump L4 + + L1: + copy 0 + load_const 2 + cmp_op == + pop_jump_if_false L2 + pop 1 + jump L3 + + L2: + pop 1 + load_const "ok" + return + + L3: + load_const user.ErrorKind.ValidationFailed + alloc_variant user.ErrorKind + throw + + L4: + load_const user.ErrorKind.Unauthorized + alloc_variant user.ErrorKind + throw + + L5: + load_const user.ErrorKind.NotFound + alloc_variant user.ErrorKind + throw +} + +function user.use_class_thrower(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.use_enum_thrower(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.use_mixed_thrower(f: () -> int) -> int { + load_var f + call_indirect + return +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__compose_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__compose_hof.snap new file mode 100644 index 0000000000..775dfe8022 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__compose_hof.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Expected token/node of kind PARAMETER_LIST, but found GENERIC_PARAM_LIST at compose_hof.baml:4:17 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__enum_class_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__enum_class_throws.snap new file mode 100644 index 0000000000..28611a6344 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__enum_class_throws.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Expected token/node of kind LLM_FUNCTION_BODY or EXPR_FUNCTION_BODY, but found THROWS_CLAUSE at enum_class_throws.baml:27:46 From 84d09cb546a61742248d2278f94a16d65a2b82bb Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 09:05:35 -0500 Subject: [PATCH 06/26] Refine typed throws semantics for higher-order functions --- .../crates/baml_compiler2_tir/src/builder.rs | 221 +++++++--------- .../src/effective_throws.rs | 62 ++++- .../baml_compiler2_tir/src/inference.rs | 32 +-- .../crates/baml_compiler2_tir/src/lib.rs | 1 + .../baml_compiler2_tir/src/throw_inference.rs | 128 +++++++--- .../src/throws_semantics.rs | 237 ++++++++++++++++++ .../function_type_throws/compose_hof.baml | 2 +- .../returned_closures.baml | 4 +- .../stored_callback_enforcement.baml | 7 + ...on_type_throws__01_lexer__compose_hof.snap | 1 + ...e_throws__01_lexer__returned_closures.snap | 2 + ...01_lexer__stored_callback_enforcement.snap | 48 ++++ ...n_type_throws__02_parser__compose_hof.snap | 54 ++-- ..._throws__02_parser__returned_closures.snap | 60 ++--- ...2_parser__stored_callback_enforcement.snap | 45 ++++ ...l_tests__function_type_throws__03_hir.snap | 12 +- ...tests__function_type_throws__04_5_mir.snap | 87 ++++++- ...l_tests__function_type_throws__04_tir.snap | 76 +++--- ..._function_type_throws__05_diagnostics.snap | 164 ++---------- ...sts__function_type_throws__06_codegen.snap | 19 +- .../baml_tests/src/compiler2_tir/mod.rs | 63 ++--- 21 files changed, 840 insertions(+), 485 deletions(-) create mode 100644 baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 1ad4f8a5bf..513b9a0a52 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -28,6 +28,7 @@ use text_size::TextRange; use crate::{ infer_context::{InferContext, RelatedLocation, TirTypeError, TypeCheckDiagnostics}, package_interface::PackageResolutionContext, + throws_semantics, ty::{Freshness, PrimitiveType, Ty, TyAttr}, }; @@ -936,19 +937,7 @@ impl<'db> TypeInferenceBuilder<'db> { // Expand type alias chains so alias-over-function types are callable. // Bare alias cycles are already caught by find_invalid_alias_cycles (Tarjan SCC) // before we reach here, so the depth guard is cheap insurance. - let callee_ty = { - let mut ty = callee_ty; - for _ in 0..64 { - match &ty { - Ty::TypeAlias(qtn, _) => match self.aliases.get(qtn) { - Some(expanded) => ty = expanded.clone(), - None => break, - }, - _ => break, - } - } - ty - }; + let callee_ty = throws_semantics::resolve_alias_chain(&callee_ty, &self.aliases); match &callee_ty { Ty::Function { @@ -1114,7 +1103,8 @@ impl<'db> TypeInferenceBuilder<'db> { // Lambda: bidirectional checking against expected function type Expr::Lambda(func_def) => { // Resolve type aliases so we can decompose function types - let expected_resolved = self.resolve_alias_chain(expected); + let expected_resolved = + throws_semantics::resolve_alias_chain(expected, &self.aliases); match &expected_resolved { Ty::Function { params: expected_params, @@ -1237,6 +1227,11 @@ impl<'db> TypeInferenceBuilder<'db> { // lambda body throws, report the stored function diagnostic if matches!(expected_throws.as_ref(), Ty::Never { .. }) && !matches!(throws_ty, Ty::Never { .. }) + && throws_semantics::function_shape_matches_ignoring_outer_throws( + &result, + &expected_resolved, + &self.aliases, + ) { self.context.report_simple( TirTypeError::StoredFunctionRequiresExplicitThrows { @@ -1869,16 +1864,18 @@ impl<'db> TypeInferenceBuilder<'db> { self.context.report_at_span(diag, span); } - let declared = crate::throw_inference::flatten_ty_to_facts(&declared_ty); let effective = self.collect_effective_throws(body); + let diff = throws_semantics::throws_contract_diff(&declared_ty, &effective, &self.aliases); - let mut extra: Vec = effective - .difference(&declared) - .map(std::string::ToString::to_string) + let mut extra: Vec = diff + .uncovered_effective + .into_iter() + .map(|ty| ty.to_string()) .collect(); - let mut extraneous: Vec = declared - .difference(&effective) - .map(std::string::ToString::to_string) + let mut extraneous: Vec = diff + .extraneous_declared + .into_iter() + .map(|ty| ty.to_string()) .collect(); extra.sort(); extraneous.sort(); @@ -2412,7 +2409,7 @@ impl<'db> TypeInferenceBuilder<'db> { at_expr, ) }; - if Self::ty_covers_fact(&lowered, throw_fact) { + if self.ty_covers_fact(&lowered, throw_fact) { PatternMatchStrength::DefiniteMatch } else if is_unknown { PatternMatchStrength::MayMatch @@ -2425,7 +2422,7 @@ impl<'db> TypeInferenceBuilder<'db> { } baml_compiler2_ast::Pattern::TypedBinding { ty, .. } => { let lowered = self.lower_pattern_type_expr(ty, at_expr); - if Self::ty_covers_fact(&lowered, throw_fact) { + if self.ty_covers_fact(&lowered, throw_fact) { PatternMatchStrength::DefiniteMatch } else if is_unknown { PatternMatchStrength::MayMatch @@ -2454,13 +2451,19 @@ impl<'db> TypeInferenceBuilder<'db> { } } baml_compiler2_ast::Pattern::EnumVariant { enum_name, variant } => { - let matches_variant = match throw_fact { - Ty::EnumVariant(qn, v, _) => { - Self::enum_name_matches(enum_name, qn) && v == variant - } - Ty::Enum(qn, _) => Self::enum_name_matches(enum_name, qn), - _ => false, - }; + let matches_variant = + match throws_semantics::resolve_alias_chain(throw_fact, &self.aliases) { + Ty::EnumVariant(qn, v, _) => { + Self::enum_name_matches(enum_name, &qn) && v == *variant + } + Ty::Enum(qn, _) => { + if Self::enum_name_matches(enum_name, &qn) { + return PatternMatchStrength::MayMatch; + } + false + } + _ => false, + }; if matches_variant { PatternMatchStrength::DefiniteMatch } else if is_unknown { @@ -2489,37 +2492,8 @@ impl<'db> TypeInferenceBuilder<'db> { } } - /// Check if a pattern type covers a throw fact type. - fn ty_covers_fact(pattern_ty: &Ty, fact: &Ty) -> bool { - match pattern_ty { - Ty::Primitive(p, _) => match fact { - Ty::Primitive(fp, _) => p == fp, - Ty::Literal(lit, _, _) => *p == PrimitiveType::from_literal(lit), - _ => false, - }, - Ty::Literal(lit, _, _) => { - let widened = Ty::Primitive(PrimitiveType::from_literal(lit), TyAttr::default()); - &widened == fact - } - Ty::Optional(inner, _) => { - matches!(fact, Ty::Primitive(PrimitiveType::Null, _)) - || Self::ty_covers_fact(inner, fact) - } - Ty::Union(parts, _) => parts.iter().any(|part| Self::ty_covers_fact(part, fact)), - Ty::Class(qn, _) => matches!(fact, Ty::Class(fqn, _) if fqn == qn), - Ty::Enum(qn, _) => match fact { - Ty::Enum(fqn, _) => fqn == qn, - Ty::EnumVariant(fqn, _, _) => fqn == qn, - _ => false, - }, - Ty::TypeAlias(qn, _) => matches!(fact, Ty::TypeAlias(fqn, _) if fqn == qn), - Ty::EnumVariant(qn, variant, _) => { - matches!(fact, Ty::EnumVariant(fqn, fv, _) if fqn == qn && fv == variant) - || matches!(fact, Ty::Enum(fqn, _) if fqn == qn) - } - Ty::BuiltinUnknown { .. } | Ty::Unknown { .. } | Ty::Error { .. } => true, - _ => false, - } + fn ty_covers_fact(&self, pattern_ty: &Ty, fact: &Ty) -> bool { + throws_semantics::type_covers_throw_fact(pattern_ty, fact, &self.aliases) } fn collect_effective_throws(&self, body: &ExprBody) -> BTreeSet { @@ -2529,6 +2503,7 @@ impl<'db> TypeInferenceBuilder<'db> { body, &self.expressions, &self.catch_residual_throws, + &self.aliases, true, true, ) @@ -2550,13 +2525,11 @@ impl<'db> TypeInferenceBuilder<'db> { for arg in args { self.collect_throw_facts_from_expr(*arg, body, out); } - let type_level_facts = self.expressions.get(callee).and_then(|ty| match ty { - Ty::Function { throws, .. } => { - let facts = crate::throw_inference::flatten_ty_to_facts(throws); - if facts.is_empty() { None } else { Some(facts) } - } - _ => None, - }); + let type_level_facts = self + .expressions + .get(callee) + .and_then(|ty| throws_semantics::function_throws_facts(ty, &self.aliases)) + .and_then(|facts| if facts.is_empty() { None } else { Some(facts) }); if let Some(facts) = type_level_facts { out.extend(facts); } else if let Some(target) = self.call_target_name(*callee, body) { @@ -2719,14 +2692,16 @@ impl<'db> TypeInferenceBuilder<'db> { out: &mut BTreeSet, ) { if let Some(thrown_ty) = self.expressions.get(&value_expr_id) { - out.extend(crate::throw_inference::flatten_ty_to_facts(thrown_ty)); + out.extend(throws_semantics::flatten_ty_to_facts(thrown_ty)); return; } match &body.exprs[value_expr_id] { - Expr::Literal(lit) => out.extend(crate::throw_inference::flatten_ty_to_facts( - &Ty::Literal(lit.clone(), Freshness::Regular, TyAttr::default()), - )), + Expr::Literal(lit) => out.extend(throws_semantics::flatten_ty_to_facts(&Ty::Literal( + lit.clone(), + Freshness::Regular, + TyAttr::default(), + ))), Expr::ByteStringLiteral(_) => { out.insert(Ty::Primitive(PrimitiveType::Uint8Array, TyAttr::default())); } @@ -2812,7 +2787,7 @@ impl<'db> TypeInferenceBuilder<'db> { let db = self.context.db(); if let baml_compiler2_hir::body::FunctionBody::Expr(ref expr_body) = *body { let scope_inference = crate::inference::infer_scope_types(db, body_scope); - crate::throw_inference::flatten_ty_to_facts(&scope_inference.effective_throws( + throws_semantics::flatten_ty_to_facts(&scope_inference.effective_throws( db, body_package_id, expr_body, @@ -2825,26 +2800,6 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn combine_effect_vars_with_body_throws( - synthetic_effect_vars: &[Name], - body_throws_facts: BTreeSet, - ) -> Ty { - let mut all_throws: Vec = synthetic_effect_vars - .iter() - .map(|v| Ty::TypeVar(v.clone(), TyAttr::default())) - .collect(); - all_throws.extend(body_throws_facts); - all_throws.retain(|t| !matches!(t, Ty::Never { .. } | Ty::Void { .. })); - - match all_throws.len() { - 0 => Ty::Never { - attr: TyAttr::default(), - }, - 1 => all_throws.remove(0), - _ => Ty::Union(all_throws, TyAttr::default()), - } - } - #[allow(clippy::too_many_arguments)] fn build_function_ty_from_signature( &self, @@ -2852,6 +2807,7 @@ impl<'db> TypeInferenceBuilder<'db> { ns_context: &[Name], generic_params: &[Name], sig: &baml_compiler2_hir::signature::FunctionSignature, + function_key: Option<&Name>, body_scope: Option>, body_package_id: Option>, body: Option<&baml_compiler2_hir::body::FunctionBody>, @@ -2926,25 +2882,33 @@ impl<'db> TypeInferenceBuilder<'db> { ) }) .unwrap_or_else(|| { - let has_callback_param = params - .iter() - .any(|(_, ty)| matches!(ty, Ty::Function { .. })); - if !has_callback_param { - return Ty::Never { - attr: TyAttr::default(), + if synthetic_effect_vars.is_empty() { + match (function_key, body_package_id) { + (Some(function_key), Some(body_package_id)) => { + crate::throw_inference::function_throw_sets(db, body_package_id) + .transitive_for(function_key) + .cloned() + .map(throws_semantics::concrete_throws_ty_from_facts) + .unwrap_or(Ty::Never { + attr: TyAttr::default(), + }) + } + _ => Ty::Never { + attr: TyAttr::default(), + }, + } + } else { + let body_throws_facts = match (body_scope, body_package_id, body) { + (Some(body_scope), Some(body_package_id), Some(body)) => { + self.infer_concrete_body_throws(body_scope, body_package_id, body) + } + _ => BTreeSet::new(), }; + throws_semantics::combine_effect_vars_with_body_throws( + &synthetic_effect_vars, + body_throws_facts, + ) } - - let body_throws_facts = match (body_scope, body_package_id, body) { - (Some(body_scope), Some(body_package_id), Some(body)) => { - self.infer_concrete_body_throws(body_scope, body_package_id, body) - } - _ => BTreeSet::new(), - }; - Self::combine_effect_vars_with_body_throws( - &synthetic_effect_vars, - body_throws_facts, - ) }); Ty::Function { @@ -3089,6 +3053,10 @@ impl<'db> TypeInferenceBuilder<'db> { let func_data_for_sig = &item_tree_for_func[func_loc.id(db)]; let generic_params = &func_data_for_sig.generic_params; let pkg_info = baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)); + let function_key = crate::throw_inference::throw_set_key( + &pkg_info.namespace_path, + &func_data_for_sig.name, + ); let ns_context = pkg_info.namespace_path; self.resolutions.insert( expr_id, @@ -3106,6 +3074,7 @@ impl<'db> TypeInferenceBuilder<'db> { &ns_context, generic_params, sig.as_ref(), + Some(&function_key), Some(func_scope), Some(PackageId::new(db, pkg_info.package)), Some(func_body.as_ref()), @@ -3158,6 +3127,8 @@ impl<'db> TypeInferenceBuilder<'db> { let sig_pkg = baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)); let sig_ns = sig_pkg.namespace_path; + let function_key = + crate::throw_inference::throw_set_key(&sig_ns, &func_data.name); let func_scope = self.find_function_scope_id( func_loc.file(db), func_data.span, @@ -3169,6 +3140,7 @@ impl<'db> TypeInferenceBuilder<'db> { &sig_ns, generic_params, sig.as_ref(), + Some(&function_key), Some(func_scope), Some(PackageId::new(db, sig_pkg.package)), Some(func_body.as_ref()), @@ -3636,6 +3608,10 @@ impl<'db> TypeInferenceBuilder<'db> { let func_loc = baml_compiler2_hir::loc::FunctionLoc::new(db, file, method_id); let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); let class_ty = Ty::Class(class_name.clone(), TyAttr::default()); + let function_key = crate::throw_inference::throw_set_key( + &ns_context, + &Name::new(format!("{}.{}", class_name.name(), method_data.name)), + ); let method_scope = self.find_function_scope_id(file, method_data.span, &method_data.name); let method_body = baml_compiler2_hir::body::function_body(db, func_loc); @@ -3644,6 +3620,7 @@ impl<'db> TypeInferenceBuilder<'db> { &ns_context, &all_generic_params, sig.as_ref(), + Some(&function_key), Some(method_scope), Some(PackageId::new( db, @@ -4493,8 +4470,8 @@ impl<'db> TypeInferenceBuilder<'db> { /// explaining that explicit `throws` annotation is required. fn report_type_mismatch(&self, expected: &Ty, got: &Ty, at: ExprId) { // Resolve type aliases for comparison - let expected_resolved = self.resolve_alias_chain(expected); - let got_resolved = self.resolve_alias_chain(got); + let expected_resolved = throws_semantics::resolve_alias_chain(expected, &self.aliases); + let got_resolved = throws_semantics::resolve_alias_chain(got, &self.aliases); // Check for stored function throws mismatch if let ( @@ -4512,6 +4489,11 @@ impl<'db> TypeInferenceBuilder<'db> { // emit the specific stored-function diagnostic if matches!(expected_throws.as_ref(), Ty::Never { .. }) && !matches!(actual_throws.as_ref(), Ty::Never { .. }) + && throws_semantics::function_shape_matches_ignoring_outer_throws( + got, + expected, + &self.aliases, + ) { self.context.report_simple( TirTypeError::StoredFunctionRequiresExplicitThrows { @@ -4534,21 +4516,6 @@ impl<'db> TypeInferenceBuilder<'db> { ); } - /// Resolve a type alias chain to its underlying type (up to a depth limit). - fn resolve_alias_chain(&self, ty: &Ty) -> Ty { - let mut resolved = ty.clone(); - for _ in 0..64 { - match &resolved { - Ty::TypeAlias(qtn, _) => match self.aliases.get(qtn) { - Some(expanded) => resolved = expanded.clone(), - None => break, - }, - _ => break, - } - } - resolved - } - fn infer_binary_op( &mut self, op: baml_compiler2_ast::BinaryOp, diff --git a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs index 2873e5e9d4..ef2aaaea2f 100644 --- a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs +++ b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashMap}; use baml_base::Name; use baml_compiler2_ast::{Expr, ExprBody, ExprId, Stmt, StmtId}; @@ -6,16 +6,19 @@ use baml_compiler2_hir::package::PackageId; use rustc_hash::FxHashMap; use crate::{ - throw_inference::{flatten_ty_to_facts, function_throw_sets}, + throw_inference::function_throw_sets, + throws_semantics::{flatten_ty_to_facts, function_throws_facts}, ty::{Freshness, QualifiedTypeName, Ty, TyAttr}, }; +#[allow(clippy::too_many_arguments)] pub(crate) fn collect_effective_throws<'db>( db: &'db dyn crate::Db, package_id: PackageId<'db>, body: &ExprBody, expressions: &FxHashMap, catch_residual_throws: &FxHashMap>, + aliases: &HashMap, include_typevars: bool, unknown_on_unresolved_call: bool, ) -> BTreeSet { @@ -28,6 +31,7 @@ pub(crate) fn collect_effective_throws<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, &mut out, @@ -50,6 +54,7 @@ fn collect_effective_throws_from_expr<'db>( body: &ExprBody, expressions: &FxHashMap, catch_residual_throws: &FxHashMap>, + aliases: &HashMap, include_typevars: bool, unknown_on_unresolved_call: bool, out: &mut BTreeSet, @@ -63,6 +68,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -77,6 +83,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -89,6 +96,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -100,6 +108,7 @@ fn collect_effective_throws_from_expr<'db>( *callee, body, expressions, + aliases, CallResolutionOptions { include_typevars, unknown_on_unresolved_call, @@ -121,6 +130,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -140,6 +150,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -151,6 +162,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -163,6 +175,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -179,6 +192,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -193,6 +207,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -205,6 +220,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -219,6 +235,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -230,6 +247,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -243,6 +261,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -259,6 +278,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -272,6 +292,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -287,6 +308,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -302,6 +324,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -313,6 +336,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -328,6 +352,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -341,6 +366,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -355,6 +381,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -368,6 +395,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -379,6 +407,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -392,6 +421,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -404,6 +434,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -418,6 +449,7 @@ fn collect_effective_throws_from_expr<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -440,6 +472,7 @@ fn collect_effective_throws_from_stmt<'db>( body: &ExprBody, expressions: &FxHashMap, catch_residual_throws: &FxHashMap>, + aliases: &HashMap, include_typevars: bool, unknown_on_unresolved_call: bool, out: &mut BTreeSet, @@ -452,6 +485,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -465,6 +499,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -484,6 +519,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -495,6 +531,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -507,6 +544,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -525,6 +563,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -536,6 +575,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -550,6 +590,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -564,6 +605,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -575,6 +617,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -588,6 +631,7 @@ fn collect_effective_throws_from_stmt<'db>( body, expressions, catch_residual_throws, + aliases, include_typevars, unknown_on_unresolved_call, out, @@ -598,19 +642,20 @@ fn collect_effective_throws_from_stmt<'db>( } } +#[allow(clippy::too_many_arguments)] fn collect_effective_throws_from_call<'db>( db: &'db dyn crate::Db, package_id: PackageId<'db>, callee_expr_id: ExprId, body: &ExprBody, expressions: &FxHashMap, + aliases: &HashMap, options: CallResolutionOptions, out: &mut BTreeSet, ) { - let type_level_facts = expressions.get(&callee_expr_id).and_then(|ty| match ty { - Ty::Function { throws, .. } => Some(flatten_ty_to_facts(throws)), - _ => None, - }); + let type_level_facts = expressions + .get(&callee_expr_id) + .and_then(|ty| function_throws_facts(ty, aliases)); if let Some(facts) = type_level_facts { let filtered: BTreeSet = facts @@ -636,7 +681,10 @@ fn collect_effective_throws_from_call<'db>( } if options.unknown_on_unresolved_call - && !matches!(expressions.get(&callee_expr_id), Some(Ty::Function { .. })) + && expressions + .get(&callee_expr_id) + .and_then(|ty| function_throws_facts(ty, aliases)) + .is_none() { out.insert(Ty::Unknown { attr: TyAttr::default(), diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index c09821dbc0..361cc42547 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -172,42 +172,20 @@ impl<'db> ScopeInference<'db> { package_id: PackageId<'db>, body: &baml_compiler2_ast::ExprBody, ) -> crate::ty::Ty { - use std::collections::BTreeSet; + let pkg_items = baml_compiler2_ppir::package_items(db, package_id); + let aliases = collect_type_aliases(db, pkg_items); - use crate::ty::{PrimitiveType, TyAttr}; - - let mut facts: BTreeSet = crate::effective_throws::collect_effective_throws( + let facts = crate::effective_throws::collect_effective_throws( db, package_id, body, &self.expressions, &self.catch_residual_throws, + &aliases, false, false, ); - - // Remove Never and Void facts (they don't represent thrown exceptions). - facts.retain(|f| !matches!(f, Ty::Never { .. } | Ty::Void { .. })); - - // Widen string literals to string primitive (matches throw_inference behavior). - let widened: BTreeSet = facts - .into_iter() - .map(|f| match &f { - Ty::Literal(baml_compiler2_ast::Literal::String(_), _, _) => { - Ty::Primitive(PrimitiveType::String, TyAttr::default()) - } - other => other.clone(), - }) - .collect(); - - let mut members: Vec = widened.into_iter().collect(); - match members.len() { - 0 => Ty::Never { - attr: TyAttr::default(), - }, - 1 => members.remove(0), - _ => Ty::Union(members, TyAttr::default()), - } + crate::throws_semantics::concrete_throws_ty_from_facts(facts) } /// Get diagnostics for this scope (empty slice if none). diff --git a/baml_language/crates/baml_compiler2_tir/src/lib.rs b/baml_language/crates/baml_compiler2_tir/src/lib.rs index abd7fbf431..bc99ba0df8 100644 --- a/baml_language/crates/baml_compiler2_tir/src/lib.rs +++ b/baml_language/crates/baml_compiler2_tir/src/lib.rs @@ -29,6 +29,7 @@ pub mod normalize; pub mod package_interface; pub mod resolve; pub mod throw_inference; +pub mod throws_semantics; pub mod ty; // ── Db trait ────────────────────────────────────────────────────────────────── diff --git a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs index dab2abc9eb..33f1135c7f 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs @@ -5,7 +5,7 @@ //! firewalls: their declared set becomes caller-visible, replacing body-derived //! facts for propagation. -use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use baml_base::Name; use baml_compiler2_ast::{Expr, ExprBody, Literal, Pattern, TypeExpr}; @@ -15,7 +15,9 @@ use baml_compiler2_hir::{ }; use crate::{ + inference::collect_type_aliases, lower_type_expr::{lower_type_expr_in_ns, qualify_def}, + throws_semantics::function_throws_facts, ty::{PrimitiveType, Ty, TyAttr}, }; @@ -64,6 +66,7 @@ pub fn function_throw_sets<'db>( package_id: PackageId<'db>, ) -> FunctionThrowSets { let pkg_items = package_items(db, package_id); + let aliases = collect_type_aliases(db, pkg_items); // Load dependency interfaces for cross-package throw lookup let dep_interfaces: Vec<(Name, &crate::package_interface::PackageInterface)> = package_dependencies(db, package_id) @@ -92,13 +95,13 @@ pub fn function_throw_sets<'db>( let key = function_key(db, *func_loc, short_name); let sig = baml_compiler2_hir::signature::function_signature(db, *func_loc); let body = baml_compiler2_hir::body::function_body(db, *func_loc); + let item_tree = baml_compiler2_hir::file_item_tree(db, func_loc.file(db)); + let func_data = &item_tree[func_loc.id(db)]; let func_ns = baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)) .namespace_path; let declared_throws = sig.throws.as_ref().map(|te| { let mut diags = Vec::new(); - let item_tree = baml_compiler2_hir::file_item_tree(db, func_loc.file(db)); - let func_data = &item_tree[func_loc.id(db)]; let lowered = lower_type_expr_in_ns( db, te, @@ -114,7 +117,17 @@ pub fn function_throw_sets<'db>( let direct = if let Some(declared) = declared_throws.clone() { declared } else if let baml_compiler2_hir::body::FunctionBody::Expr(expr_body) = body.as_ref() { - collect_direct_throws(db, pkg_items, &func_ns, expr_body) + let mut direct = collect_direct_throws(db, pkg_items, &func_ns, expr_body); + direct.extend(collect_direct_param_call_throws( + db, + pkg_items, + &func_ns, + &func_data.generic_params, + sig.as_ref(), + expr_body, + &aliases, + )); + direct } else { BTreeSet::new() }; @@ -149,6 +162,8 @@ pub fn function_throw_sets<'db>( let method_ns = baml_compiler2_hir::file_package::file_package(db, file).namespace_path; + let mut method_generic_params = class_data.generic_params.clone(); + method_generic_params.extend(method_data.generic_params.iter().cloned()); let declared_throws = sig.throws.as_ref().map(|te| { let mut diags = Vec::new(); let lowered = lower_type_expr_in_ns( @@ -156,7 +171,7 @@ pub fn function_throw_sets<'db>( te, pkg_items, &method_ns, - &method_data.generic_params, + &method_generic_params, &mut diags, ); drop(diags); @@ -168,7 +183,17 @@ pub fn function_throw_sets<'db>( } else if let baml_compiler2_hir::body::FunctionBody::Expr(expr_body) = body.as_ref() { - collect_direct_throws(db, pkg_items, &method_ns, expr_body) + let mut direct = collect_direct_throws(db, pkg_items, &method_ns, expr_body); + direct.extend(collect_direct_param_call_throws( + db, + pkg_items, + &method_ns, + &method_generic_params, + sig.as_ref(), + expr_body, + &aliases, + )); + direct } else { BTreeSet::new() }; @@ -482,40 +507,75 @@ fn expr_to_path(expr_id: baml_compiler2_ast::ExprId, body: &ExprBody) -> Option< } } -/// Flatten a compound `Ty` into its leaf throw facts. -/// Unions and optionals are decomposed; leaf types are kept as-is. pub fn flatten_ty_to_facts(ty: &Ty) -> BTreeSet { - let mut out = BTreeSet::new(); - collect_leaf_types(ty, &mut out); - out + crate::throws_semantics::flatten_ty_to_facts(ty) } -fn collect_leaf_types(ty: &Ty, out: &mut BTreeSet) { - match ty { - // Compound types: decompose - Ty::Optional(inner, _) => { - collect_leaf_types(inner, out); - out.insert(Ty::Primitive(PrimitiveType::Null, TyAttr::default())); - } - Ty::Union(members, _) => { - for member in members { - collect_leaf_types(member, out); - } - } - // Literal types: widen to primitive for throw fact purposes - Ty::Literal(lit, _, _) => { - out.insert(Ty::Primitive( - PrimitiveType::from_literal(lit), - TyAttr::default(), - )); +fn collect_direct_param_call_throws<'db>( + db: &'db dyn crate::Db, + pkg_items: &PackageItems<'db>, + ns_context: &[Name], + generic_params: &[Name], + sig: &baml_compiler2_hir::signature::FunctionSignature, + body: &ExprBody, + aliases: &HashMap, +) -> BTreeSet { + let mut direct_param_throws: HashMap> = HashMap::new(); + + for (param_name, param_ty_expr) in &sig.params { + let mut diags = Vec::new(); + let lowered = lower_type_expr_in_ns( + db, + param_ty_expr, + pkg_items, + ns_context, + generic_params, + &mut diags, + ); + drop(diags); + + let Some(facts) = function_throws_facts(&lowered, aliases) else { + continue; + }; + + let concrete_facts: BTreeSet = facts + .into_iter() + .filter(|fact| !matches!(fact, Ty::TypeVar(_, _) | Ty::Never { .. } | Ty::Void { .. })) + .collect(); + + if !concrete_facts.is_empty() { + direct_param_throws.insert(param_name.clone(), concrete_facts); } - // Bottom/void: no facts - Ty::Never { .. } | Ty::Void { .. } => {} - // Everything else: keep as-is - _ => { - out.insert(ty.clone()); + } + + if direct_param_throws.is_empty() { + return BTreeSet::new(); + } + + let mut out = BTreeSet::new(); + for (_, expr) in body.exprs.iter() { + let callee = match expr { + Expr::Call { callee, .. } | Expr::OptionalCall { callee, .. } => *callee, + _ => continue, + }; + + let Some(param_name) = direct_param_callee_name(callee, body) else { + continue; + }; + + if let Some(facts) = direct_param_throws.get(¶m_name) { + out.extend(facts.iter().cloned()); } } + + out +} + +fn direct_param_callee_name(expr_id: baml_compiler2_ast::ExprId, body: &ExprBody) -> Option { + match &body.exprs[expr_id] { + Expr::Path(segments) if segments.len() == 1 => Some(segments[0].clone()), + _ => None, + } } /// Look up a function's transitive throw set from dependency interfaces. diff --git a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs new file mode 100644 index 0000000000..d352ff75fc --- /dev/null +++ b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs @@ -0,0 +1,237 @@ +use std::collections::{BTreeSet, HashMap}; + +use baml_base::Name; + +use crate::{ + normalize, + ty::{PrimitiveType, QualifiedTypeName, Ty, TyAttr}, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ThrowsContractDiff { + pub uncovered_effective: Vec, + pub extraneous_declared: Vec, +} + +pub(crate) fn resolve_alias_chain(ty: &Ty, aliases: &HashMap) -> Ty { + let mut resolved = ty.clone(); + for _ in 0..64 { + match &resolved { + Ty::TypeAlias(qtn, _) => match aliases.get(qtn) { + Some(expanded) => resolved = expanded.clone(), + None => break, + }, + _ => break, + } + } + resolved +} + +pub(crate) fn function_throws_facts( + ty: &Ty, + aliases: &HashMap, +) -> Option> { + match resolve_alias_chain(ty, aliases) { + Ty::Function { throws, .. } => Some(flatten_ty_to_facts(&throws)), + _ => None, + } +} + +/// Flatten a compound `Ty` into its leaf throw facts. +/// Unions and optionals are decomposed; leaf types are kept as-is. +pub fn flatten_ty_to_facts(ty: &Ty) -> BTreeSet { + let mut out = BTreeSet::new(); + collect_leaf_types(ty, &mut out); + out +} + +fn collect_leaf_types(ty: &Ty, out: &mut BTreeSet) { + match ty { + Ty::Optional(inner, _) => { + collect_leaf_types(inner, out); + out.insert(Ty::Primitive(PrimitiveType::Null, TyAttr::default())); + } + Ty::Union(members, _) => { + for member in members { + collect_leaf_types(member, out); + } + } + Ty::Literal(lit, _, _) => { + out.insert(Ty::Primitive( + PrimitiveType::from_literal(lit), + TyAttr::default(), + )); + } + Ty::Never { .. } | Ty::Void { .. } => {} + _ => { + out.insert(ty.clone()); + } + } +} + +pub(crate) fn throws_contract_diff( + declared_ty: &Ty, + effective_facts: &BTreeSet, + aliases: &HashMap, +) -> ThrowsContractDiff { + let declared_facts = flatten_ty_to_facts(declared_ty); + + let uncovered_effective = effective_facts + .iter() + .filter(|fact| { + !declared_facts + .iter() + .any(|declared| declared_covers_fact(declared, fact, aliases)) + }) + .cloned() + .collect(); + + let extraneous_declared = declared_facts + .iter() + .filter(|declared| { + !effective_facts + .iter() + .any(|fact| declared_covers_fact(declared, fact, aliases)) + }) + .cloned() + .collect(); + + ThrowsContractDiff { + uncovered_effective, + extraneous_declared, + } +} + +pub(crate) fn type_covers_throw_fact( + pattern_ty: &Ty, + fact: &Ty, + aliases: &HashMap, +) -> bool { + if matches!( + fact, + Ty::Unknown { .. } | Ty::BuiltinUnknown { .. } | Ty::Error { .. } + ) { + let resolved = resolve_alias_chain(pattern_ty, aliases); + return matches!( + resolved, + Ty::Unknown { .. } | Ty::BuiltinUnknown { .. } | Ty::Error { .. } + ); + } + + normalize::is_subtype_of(fact, pattern_ty, aliases) +} + +pub fn combine_effect_vars_with_body_throws( + synthetic_effect_vars: &[Name], + body_throws_facts: BTreeSet, +) -> Ty { + let mut all_throws: Vec = synthetic_effect_vars + .iter() + .map(|v| Ty::TypeVar(v.clone(), TyAttr::default())) + .collect(); + all_throws.extend(body_throws_facts); + all_throws.retain(|t| !matches!(t, Ty::Never { .. } | Ty::Void { .. })); + + match all_throws.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => all_throws.remove(0), + _ => Ty::Union(all_throws, TyAttr::default()), + } +} + +pub fn concrete_throws_ty_from_facts(facts: BTreeSet) -> Ty { + let mut concrete = BTreeSet::new(); + for fact in facts { + if matches!(fact, Ty::Never { .. } | Ty::Void { .. } | Ty::TypeVar(_, _)) { + continue; + } + let widened = match fact { + Ty::Literal(lit, _, _) => { + Ty::Primitive(PrimitiveType::from_literal(&lit), TyAttr::default()) + } + other => other, + }; + concrete.insert(widened); + } + + match concrete.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => concrete.into_iter().next().unwrap_or(Ty::Never { + attr: TyAttr::default(), + }), + _ => Ty::Union(concrete.into_iter().collect(), TyAttr::default()), + } +} + +pub(crate) fn function_shape_matches_ignoring_outer_throws( + got: &Ty, + expected: &Ty, + aliases: &HashMap, +) -> bool { + let got_resolved = resolve_alias_chain(got, aliases); + let expected_resolved = resolve_alias_chain(expected, aliases); + + match (got_resolved, expected_resolved) { + ( + Ty::Function { + params: got_params, + ret: got_ret, + attr: got_attr, + .. + }, + Ty::Function { + params: expected_params, + ret: expected_ret, + attr: expected_attr, + .. + }, + ) => normalize::is_subtype_of( + &Ty::Function { + params: got_params, + ret: got_ret, + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), + attr: got_attr, + }, + &Ty::Function { + params: expected_params, + ret: expected_ret, + throws: Box::new(Ty::Never { + attr: TyAttr::default(), + }), + attr: expected_attr, + }, + aliases, + ), + _ => false, + } +} + +fn declared_covers_fact( + declared: &Ty, + fact: &Ty, + aliases: &HashMap, +) -> bool { + match fact { + Ty::Unknown { .. } | Ty::Error { .. } => { + let resolved = resolve_alias_chain(declared, aliases); + matches!( + resolved, + Ty::Unknown { .. } | Ty::BuiltinUnknown { .. } | Ty::Error { .. } + ) + } + Ty::BuiltinUnknown { .. } => { + let resolved = resolve_alias_chain(declared, aliases); + matches!( + resolved, + Ty::BuiltinUnknown { .. } | Ty::Unknown { .. } | Ty::Error { .. } + ) + } + _ => normalize::is_subtype_of(fact, declared, aliases), + } +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml b/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml index 587f615da8..0968af18a8 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml @@ -5,7 +5,7 @@ function compose( f: (A) -> B, g: (B) -> C ) -> (A) -> C { - (a: A) -> C { g(f(a)) } + return (a: A) -> C { g(f(a)) } } // Pure composition diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/returned_closures.baml b/baml_language/crates/baml_tests/projects/function_type_throws/returned_closures.baml index 1be16e1bb5..0c8f86f9bc 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/returned_closures.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/returned_closures.baml @@ -2,12 +2,12 @@ // Returning a pure closure - should be fine function make_pure() -> (() -> int) { - () -> int { 42 } + return () -> int { 42 } } // Returning a closure with explicit throws function make_thrower() -> (() -> int throws string) { - () -> int { throw "error" } + return () -> int { throw "error" } } // Using a returned closure diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml b/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml index f02b081d56..cb1260e7b9 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml @@ -36,6 +36,13 @@ function test_stored_local_error() -> null { null } +// ERROR: wrong function shape should report the structural mismatch, not the +// stored-throws diagnostic. +function test_stored_local_shape_mismatch() -> null { + let f: (int) -> null = () -> null { throw "oops" } + null +} + // OK: assigning pure lambda to variable typed throws never function test_stored_local_ok() -> null { let f: () -> null = () -> null { null } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap index f178332743..a6ef6a6fab 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap @@ -58,6 +58,7 @@ RParen ")" Arrow "->" Word "C" LBrace "{" +Return "return" LParen "(" Word "a" Colon ":" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__returned_closures.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__returned_closures.snap index 374c9cb45c..117341fc9f 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__returned_closures.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__returned_closures.snap @@ -37,6 +37,7 @@ Arrow "->" Word "int" RParen ")" LBrace "{" +Return "return" LParen "(" RParen ")" Arrow "->" @@ -67,6 +68,7 @@ Throws "throws" Word "string" RParen ")" LBrace "{" +Return "return" LParen "(" RParen ")" Arrow "->" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap index 5249184d9d..bcada89efc 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap @@ -227,6 +227,54 @@ Word "null" RBrace "}" Slash "/" Slash "/" +Word "ERROR" +Colon ":" +Word "wrong" +Function "function" +Word "shape" +Word "should" +Word "report" +Word "the" +Word "structural" +Word "mismatch" +Comma "," +Word "not" +Word "the" +Slash "/" +Slash "/" +Word "stored-throws" +Word "diagnostic" +Dot "." +Function "function" +Word "test_stored_local_shape_mismatch" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "f" +Colon ":" +LParen "(" +Word "int" +RParen ")" +Arrow "->" +Word "null" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "oops" +Quote "\"" +RBrace "}" +Word "null" +RBrace "}" +Slash "/" +Slash "/" Word "OK" Colon ":" Word "assigning" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__compose_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__compose_hof.snap index 732d5c9073..312cbfb0de 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__compose_hof.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__compose_hof.snap @@ -58,32 +58,34 @@ SOURCE_FILE EXPR_FUNCTION_BODY BLOCK_EXPR L_BRACE "{" - LAMBDA_EXPR - PARAMETER_LIST - L_PAREN "(" - PARAMETER - WORD "a" - COLON ":" - TYPE_EXPR "A" - WORD "A" - R_PAREN ")" - ARROW "->" - TYPE_EXPR "C" - WORD "C" - BLOCK_EXPR - L_BRACE "{" - CALL_EXPR - WORD "g" - CALL_ARGS - L_PAREN "(" - CALL_EXPR - WORD "f" - CALL_ARGS "(a)" - L_PAREN "(" - WORD "a" - R_PAREN ")" - R_PAREN ")" - R_BRACE "}" + RETURN_STMT + KW_RETURN "return" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "a" + COLON ":" + TYPE_EXPR "A" + WORD "A" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "C" + WORD "C" + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "g" + CALL_ARGS + L_PAREN "(" + CALL_EXPR + WORD "f" + CALL_ARGS "(a)" + L_PAREN "(" + WORD "a" + R_PAREN ")" + R_PAREN ")" + R_BRACE "}" R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__returned_closures.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__returned_closures.snap index 777f1272b9..cba8c163a5 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__returned_closures.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__returned_closures.snap @@ -23,17 +23,19 @@ SOURCE_FILE EXPR_FUNCTION_BODY BLOCK_EXPR L_BRACE "{" - LAMBDA_EXPR - PARAMETER_LIST "()" - L_PAREN "(" - R_PAREN ")" - ARROW "->" - TYPE_EXPR "int" - WORD "int" - BLOCK_EXPR "{ 42 }" - L_BRACE "{" - INTEGER_LITERAL "42" - R_BRACE "}" + RETURN_STMT + KW_RETURN "return" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ 42 }" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" @@ -59,23 +61,25 @@ SOURCE_FILE EXPR_FUNCTION_BODY BLOCK_EXPR L_BRACE "{" - LAMBDA_EXPR - PARAMETER_LIST "()" - L_PAREN "(" - R_PAREN ")" - ARROW "->" - TYPE_EXPR "int" - WORD "int" - BLOCK_EXPR - L_BRACE "{" - THROW_STMT - THROW_EXPR - KW_THROW "throw" - STRING_LITERAL "error" - QUOTE """ - WORD "error" - QUOTE """ - R_BRACE "}" + RETURN_STMT + KW_RETURN "return" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap index 03836d3b8a..d2b0649ece 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap @@ -182,6 +182,51 @@ SOURCE_FILE R_BRACE "}" WORD "null" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_stored_local_shape_mismatch" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "oops" + QUOTE """ + WORD "oops" + QUOTE """ + R_BRACE "}" + WORD "null" + R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" WORD "test_stored_local_ok" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index 9f0d8a87da..e649b3c3c7 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -56,7 +56,7 @@ function user.make_throwing_handler() -> user.ThrowingHandler [expr] { { } user.ThrowingHandler { run: () -> null { { throw "error" } } } } function user.compose(f: (user.A) -> user.B, g: (user.B) -> user.C) -> (user.A) -> user.C [expr] { - { } + { return (a: A) -> C { { } g(f(a)) } } } function user.test_compose_both_throw() -> (int) -> string [expr] { { } compose((x: int) -> int { { throw "string error" } }, (y: int) -> string { { throw 42 } }) @@ -70,6 +70,9 @@ function user.test_compose_pure() -> (int) -> string [expr] { function user.test_compose_second_throws() -> (int) -> string [expr] { { } compose((x: int) -> int { { } x Mul 2 }, (y: int) -> string { { throw "g failed" } }) } + +--- captures --- +lambda (a) in compose: captures [g, f] class user.ApiError { code: int message: string @@ -331,10 +334,10 @@ function user.test_nested_outer_throws() -> int [expr] { { let outer = () -> int { { let inner = () -> int { { } 42 }; throw "outer boom" } } } outer() } function user.make_pure() -> () -> int [expr] { - { } + { return () -> int { { } 42 } } } function user.make_thrower() -> () -> int [expr] { - { } + { return () -> int { { throw "error" } } } } function user.test_use_pure() -> int [expr] { { let f = make_pure() } f() @@ -379,6 +382,9 @@ function user.test_stored_local_error() -> null [expr] { function user.test_stored_local_ok() -> null [expr] { { let f: () -> null = () -> null { { } null } } null } +function user.test_stored_local_shape_mismatch() -> null [expr] { + { let f: (int) -> null = () -> null { { throw "oops" } } } null +} function user.test_stored_local_with_throws() -> null [expr] { { let f: () -> null = () -> null { { throw "oops" } } } null } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index bd1a35c0b1..6b97a629ec 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -423,11 +423,11 @@ fn .() -> null { fn user.compose(f: (void) -> void, g: (void) -> void) -> (void) -> void { // Locals: let _0: (void) -> void // _0 // return - let _1: (void) -> void // f // param - let _2: (void) -> void // g // param + let _1: (void) -> void // f // param [captured] + let _2: (void) -> void // g // param [captured] bb0: { - _0 = const null; + _0 = make_closure lambda[0](copy _2, copy _1); goto -> bb1; } @@ -436,6 +436,34 @@ fn user.compose(f: (void) -> void, g: (void) -> void) -> (void) -> void { } } +// lambda[0] +fn .(a: void) -> null { + // Locals: + let _0: null // _0 // return + let _1: void // a // param + let _2: void + let _3: void + let _4: void + + bb0: { + _2 = copy capture[0]; + _4 = copy capture[1]; + _3 = call copy _4(copy _1) -> [bb1]; + } + + bb1: { + _0 = call copy _2(copy _3) -> [bb2]; + } + + bb2: { + goto -> bb3; + } + + bb3: { + return; + } +} + fn user.test_compose_both_throw() -> (int) -> string { // Locals: let _0: (int) -> string // _0 // return @@ -3096,7 +3124,22 @@ fn user.make_pure() -> () -> int { let _0: () -> int // _0 // return bb0: { - _0 = const null; + _0 = make_closure lambda[0](); + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const 42_i64; goto -> bb1; } @@ -3110,7 +3153,7 @@ fn user.make_thrower() -> () -> int { let _0: () -> int // _0 // return bb0: { - _0 = const null; + _0 = make_closure lambda[0](); goto -> bb1; } @@ -3119,6 +3162,16 @@ fn user.make_thrower() -> () -> int { } } +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + fn user.test_use_pure() -> int { // Locals: let _0: int // _0 // return @@ -3433,6 +3486,30 @@ fn .() -> null { } } +fn user.test_stored_local_shape_mismatch() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "oops"; + } +} + fn user.test_stored_local_with_throws() -> null { // Locals: let _0: null // _0 // return diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 87e37ba59f..5b63867b30 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -142,9 +142,15 @@ class user.MixedHandler$stream { risky: unknown } function user.compose(f: (A) -> B throws __throws_f, g: (B) -> C throws __throws_g) -> (A) -> C throws __throws_f | __throws_g { - { : (A) -> C + { : never + return : (a: A) -> C + (a: A) -> C { ... } : (a: A) -> C + { + g(f(a)) + } } - !! 193..223: missing return: expected `(A) -> C` +} +lambda user.compose { } function user.test_compose_pure() -> (int) -> string throws never { { : (int) -> string | "result" @@ -235,8 +241,6 @@ function user.throw_enum_variant(x: int) -> string throws user.ErrorKind { } "ok" : "ok" } - !! 411..421: throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized - ?? 411..421: extraneous throws declaration: user.ErrorKind } function user.test_lambda_throws_enum() -> int throws user.ErrorKind.ValidationFailed { { : never @@ -283,8 +287,6 @@ function user.throw_various_errors(x: int) -> string throws user.ErrorKind { _ => "ok" : "ok" } - !! 1204..1214: throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized, user.ErrorKind.ValidationFailed - ?? 1204..1214: extraneous throws declaration: user.ErrorKind } function user.throw_mixed_classes(x: int) -> string throws user.ApiError | user.ValidationError { { : "ok" @@ -311,8 +313,6 @@ function user.throw_enum_or_class(x: int) -> string throws user.ErrorKind | user } "ok" : "ok" } - !! 1729..1750: throws contract violation: `user.ErrorKind | user.ApiError` is missing user.ErrorKind.RateLimited - ?? 1729..1750: extraneous throws declaration: user.ErrorKind } function user.throw_any_error(x: int) -> string throws string | user.ErrorKind | user.ApiError { { : "ok" @@ -326,8 +326,6 @@ function user.throw_any_error(x: int) -> string throws string | user.ErrorKind | _ => "ok" : "ok" } - !! 1968..1998: throws contract violation: `string | user.ErrorKind | user.ApiError` is missing user.ErrorKind.NotFound - ?? 1968..1998: extraneous throws declaration: user.ErrorKind } function user.apply_may_throw_enum(f: () -> int throws user.ErrorKind) -> int throws user.ErrorKind { { : int @@ -404,8 +402,6 @@ function user.catch_enum_variants(x: int) -> string throws never { _ => "other error" : "other error" } - ?? 3367..3381: unreachable arm - ?? 3392..3405: unreachable arm } function user.catch_class_error(x: int) -> string throws never { { : string @@ -426,7 +422,6 @@ function user.catch_mixed_errors(x: int) -> string throws never { _ => "unknown" : "unknown" } - ?? 3721..3730: unreachable arm } class user.EnumThrowingHandler { run: () -> null throws user.ErrorKind @@ -454,22 +449,22 @@ lambda user.make_class_handler { type user.EnumThrower = () -> int throws user.ErrorKind type user.ClassThrower = () -> int throws user.ApiError type user.MixedThrower = () -> int throws user.ErrorKind | user.ApiError | string -function user.use_enum_thrower(f: user.EnumThrower) -> int throws never { +function user.use_enum_thrower(f: user.EnumThrower) -> int throws user.ErrorKind { { : int f() : int } } -function user.use_class_thrower(f: user.ClassThrower) -> int throws never { +function user.use_class_thrower(f: user.ClassThrower) -> int throws user.ApiError { { : int f() : int } } -function user.use_mixed_thrower(f: user.MixedThrower) -> int throws never { +function user.use_mixed_thrower(f: user.MixedThrower) -> int throws user.ApiError | user.ErrorKind | string { { : int f() : int } } -function user.test_type_alias_enum() -> int throws never { +function user.test_type_alias_enum() -> int throws user.ErrorKind { { : int use_enum_thrower(() -> int { ... }) : int () -> int { ... } : () -> never throws user.ErrorKind.ValidationFailed @@ -480,7 +475,7 @@ function user.test_type_alias_enum() -> int throws never { } lambda user.test_type_alias_enum { } -function user.test_type_alias_class() -> int throws never { +function user.test_type_alias_class() -> int throws user.ApiError { { : int use_class_thrower(() -> int { ... }) : int () -> int { ... } : () -> never throws user.ApiError @@ -491,7 +486,7 @@ function user.test_type_alias_class() -> int throws never { } lambda user.test_type_alias_class { } -function user.test_type_alias_mixed() -> int throws never { +function user.test_type_alias_mixed() -> int throws user.ApiError | user.ErrorKind | string { { : int use_mixed_thrower(() -> int { ... }) : int () -> int { ... } : () -> never throws string @@ -602,7 +597,7 @@ type user.PureCallback$stream = unknown type user.ExplicitPure$stream = unknown type user.Mapper$stream = unknown type user.Wrapper$stream = unknown -function user.apply_guarded(f: () -> int throws __throws_f) -> int throws string | __throws_f { +function user.apply_guarded(f: () -> int throws __throws_f) -> int throws __throws_f | string { { : int let result = f() : int if (result < 0 : bool) : void @@ -792,7 +787,7 @@ function user.test_apply_explicit_throws() -> int throws string { } lambda user.test_apply_explicit_throws { } -function user.apply_and_throw(f: () -> int throws __throws_f) -> int throws string | __throws_f { +function user.apply_and_throw(f: () -> int throws __throws_f) -> int throws __throws_f | string { { : int let result = f() : int if (result < 0 : bool) : void @@ -818,7 +813,7 @@ function user.helper_with_body_throw() -> int throws string { throw "helper boom" : "helper boom" } } -function user.apply_with_helper(f: () -> int throws __throws_f) -> int throws string | __throws_f { +function user.apply_with_helper(f: () -> int throws __throws_f) -> int throws __throws_f | string { { : int helper_with_body_throw() : int f() : int @@ -1068,14 +1063,26 @@ lambda user.test_nested_both_throw { lambda user.test_nested_both_throw { } function user.make_pure() -> () -> int throws never { - { : () -> int + { : never + return : () -> 42 + () -> int { ... } : () -> 42 + { + 42 + } } - !! 142..165: missing return: expected `() -> int` +} +lambda user.make_pure { } function user.make_thrower() -> () -> int throws string throws never { - { : () -> int throws string + { : never + return : () -> never throws string + () -> int { ... } : () -> never throws string + { + throw "error" + } } - !! 263..297: missing return: expected `() -> int throws string` +} +lambda user.make_thrower { } function user.test_use_pure() -> int throws never { { : int @@ -1130,6 +1137,19 @@ function user.test_stored_local_error() -> null throws never { } lambda user.test_stored_local_error { } +function user.test_stored_local_shape_mismatch() -> null throws never { + { : null + let f = : () -> never throws string + () -> null { ... } : () -> never throws string + { + throw "oops" + } + null : null + } + !! 1283..1311: expected 1 argument(s), got 0 +} +lambda user.test_stored_local_shape_mismatch { +} function user.test_stored_local_ok() -> null throws never { { : null let f = : () -> null @@ -1164,7 +1184,7 @@ function user.test_stored_alias_error() -> null throws never { } null : null } - !! 1651..1679: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + !! 1875..1903: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type } lambda user.test_stored_alias_error { } @@ -1188,7 +1208,7 @@ function user.make_stored_closure_bad() -> () -> int throws never { throw "oops" } } - !! 1980..2007: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + !! 2204..2231: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type } lambda user.make_stored_closure_bad { } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index cc5e7bd3d2..8d838111aa 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -2,128 +2,6 @@ source: crates/baml_tests/src/generated_tests.rs --- === COMPILER2 DIAGNOSTICS === - [type] Error: missing return: expected `(A) -> C` - ╭─[ compose_hof.baml:7:14 ] - │ - 7 │ ╭─▶ ) -> (A) -> C { - ┆ ┆ - 9 │ ├─▶ } - │ │ - │ ╰─────── missing return: expected `(A) -> C` - │ - │ Note: Error code: E0001 -───╯ - - [type] Warning: extraneous throws declaration: user.ErrorKind - ╭─[ enum_class_throws.baml:27:53 ] - │ - 27 │ function throw_enum_variant(x: int) -> string throws ErrorKind { - │ ─────┬──── - │ ╰────── extraneous throws declaration: user.ErrorKind - │ - │ Note: Error code: E0001 -────╯ - - [type] Error: throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized - ╭─[ enum_class_throws.baml:27:53 ] - │ - 27 │ function throw_enum_variant(x: int) -> string throws ErrorKind { - │ ─────┬──── - │ ╰────── throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized - │ - │ Note: Error code: E0001 -────╯ - - [type] Warning: extraneous throws declaration: user.ErrorKind - ╭─[ enum_class_throws.baml:60:55 ] - │ - 60 │ function throw_various_errors(x: int) -> string throws ErrorKind { - │ ─────┬──── - │ ╰────── extraneous throws declaration: user.ErrorKind - │ - │ Note: Error code: E0001 -────╯ - - [type] Error: throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized, user.ErrorKind.ValidationFailed - ╭─[ enum_class_throws.baml:60:55 ] - │ - 60 │ function throw_various_errors(x: int) -> string throws ErrorKind { - │ ─────┬──── - │ ╰────── throws contract violation: `user.ErrorKind` is missing user.ErrorKind.NotFound, user.ErrorKind.Unauthorized, user.ErrorKind.ValidationFailed - │ - │ Note: Error code: E0001 -────╯ - - [type] Warning: extraneous throws declaration: user.ErrorKind - ╭─[ enum_class_throws.baml:81:54 ] - │ - 81 │ function throw_enum_or_class(x: int) -> string throws ErrorKind | ApiError { - │ ──────────┬────────── - │ ╰──────────── extraneous throws declaration: user.ErrorKind - │ - │ Note: Error code: E0001 -────╯ - - [type] Error: throws contract violation: `user.ErrorKind | user.ApiError` is missing user.ErrorKind.RateLimited - ╭─[ enum_class_throws.baml:81:54 ] - │ - 81 │ function throw_enum_or_class(x: int) -> string throws ErrorKind | ApiError { - │ ──────────┬────────── - │ ╰──────────── throws contract violation: `user.ErrorKind | user.ApiError` is missing user.ErrorKind.RateLimited - │ - │ Note: Error code: E0001 -────╯ - - [type] Warning: extraneous throws declaration: user.ErrorKind - ╭─[ enum_class_throws.baml:88:50 ] - │ - 88 │ function throw_any_error(x: int) -> string throws string | ErrorKind | ApiError { - │ ───────────────┬────────────── - │ ╰──────────────── extraneous throws declaration: user.ErrorKind - │ - │ Note: Error code: E0001 -────╯ - - [type] Error: throws contract violation: `string | user.ErrorKind | user.ApiError` is missing user.ErrorKind.NotFound - ╭─[ enum_class_throws.baml:88:50 ] - │ - 88 │ function throw_any_error(x: int) -> string throws string | ErrorKind | ApiError { - │ ───────────────┬────────────── - │ ╰──────────────── throws contract violation: `string | user.ErrorKind | user.ApiError` is missing user.ErrorKind.NotFound - │ - │ Note: Error code: E0001 -────╯ - - [type] Warning: unreachable arm - ╭─[ enum_class_throws.baml:140:31 ] - │ - 140 │ ErrorKind.Unauthorized => "unauthorized", - │ ───────┬────── - │ ╰──────── unreachable arm - │ - │ Note: Error code: E0001 -─────╯ - - [type] Warning: unreachable arm - ╭─[ enum_class_throws.baml:141:10 ] - │ - 141 │ _ => "other error" - │ ──────┬────── - │ ╰──────── unreachable arm - │ - │ Note: Error code: E0001 -─────╯ - - [type] Warning: unreachable arm - ╭─[ enum_class_throws.baml:155:10 ] - │ - 155 │ _ => "unknown" - │ ────┬──── - │ ╰────── unreachable arm - │ - │ Note: Error code: E0001 -─────╯ - [validation] Error: Duplicate function `may_fail` ╭─[ catch_absorbs_throws.baml:3:10 ] │ @@ -212,30 +90,6 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ - [type] Error: missing return: expected `() -> int` - ╭─[ returned_closures.baml:4:36 ] - │ - 4 │ ╭─▶ function make_pure() -> (() -> int) { - ┆ ┆ - 6 │ ├─▶ } - │ │ - │ ╰─────── missing return: expected `() -> int` - │ - │ Note: Error code: E0001 -───╯ - - [type] Error: missing return: expected `() -> int throws string` - ╭─[ returned_closures.baml:9:53 ] - │ - 9 │ ╭─▶ function make_thrower() -> (() -> int throws string) { - ┆ ┆ - 11 │ ├─▶ } - │ │ - │ ╰─────── missing return: expected `() -> int throws string` - │ - │ Note: Error code: E0001 -────╯ - [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type ╭─[ stored_callback_enforcement.baml:12:27 ] │ @@ -256,10 +110,20 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ + [type] Error: expected 1 argument(s), got 0 + ╭─[ stored_callback_enforcement.baml:42:25 ] + │ + 42 │ let f: (int) -> null = () -> null { throw "oops" } + │ ──────────────┬───────────── + │ ╰─────────────── expected 1 argument(s), got 0 + │ + │ Note: Error code: E0001 +────╯ + [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type - ╭─[ stored_callback_enforcement.baml:57:25 ] + ╭─[ stored_callback_enforcement.baml:64:25 ] │ - 57 │ let cb: StoredPureCb = () -> int { throw "error" } + 64 │ let cb: StoredPureCb = () -> int { throw "error" } │ ──────────────┬───────────── │ ╰─────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type │ @@ -267,9 +131,9 @@ source: crates/baml_tests/src/generated_tests.rs ────╯ [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type - ╭─[ stored_callback_enforcement.baml:71:9 ] + ╭─[ stored_callback_enforcement.baml:78:9 ] │ - 71 │ return () -> int { throw "oops" } + 78 │ return () -> int { throw "oops" } │ ─────────────┬───────────── │ ╰─────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type │ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 14b4230d46..1bbe3effc3 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -229,7 +229,15 @@ function user.catch_mixed_errors(x: int) -> string { } function user.compose(f: (void) -> void, g: (void) -> void) -> (void) -> void { - load_const null + load_var ?1 + make_cell + store_var ?1 + load_var ?2 + make_cell + store_var ?2 + load_var g + load_var f + make_closure ., 2 return } @@ -271,7 +279,7 @@ function user.make_good_stored_handler() -> StoredPureHandler { } function user.make_pure() -> () -> int { - load_const null + make_closure ., 0 return } @@ -307,7 +315,7 @@ function user.make_stored_throwing_handler() -> StoredThrowingHandler { } function user.make_thrower() -> () -> int { - load_const null + make_closure ., 0 return } @@ -704,6 +712,11 @@ function user.test_stored_local_ok() -> null { return } +function user.test_stored_local_shape_mismatch() -> null { + load_const null + return +} + function user.test_stored_local_with_throws() -> null { load_const null return diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs index fb8b02dc2b..00fefc9309 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs @@ -42,6 +42,10 @@ pub(crate) mod support { lower_type_expr::{ FnTypeLoweringContext, lower_type_expr_in_ns, lower_type_expr_with_fn_context, }, + throws_semantics::{ + combine_effect_vars_with_body_throws, concrete_throws_ty_from_facts, + flatten_ty_to_facts, + }, }; use baml_project::ProjectDatabase; @@ -1077,22 +1081,18 @@ pub(crate) mod support { }) .unwrap_or_else(|| "?".into()); // Compute inferred throws from transitive throw set - let inferred_throws: Option = { + let inferred_throws: Option = { let key = baml_base::Name::new(&*fqn); throw_sets .transitive_for(&key) .filter(|facts| !facts.is_empty()) - .map(|facts| { - let types: Vec = - facts.iter().map(|f| f.to_string()).collect(); - types.join(" | ") - }) + .map(|facts| concrete_throws_ty_from_facts(facts.clone())) }; // Compute post-inference effective throws from the function body. // This captures HOF effect propagation that the pre-inference // throw_sets cannot see. - let post_inference_throws: Option = { + let post_inference_throws: Option = { if let Some(ref fb) = func_body_opt { if let baml_compiler2_hir::body::FunctionBody::Expr(ref body) = **fb { @@ -1100,7 +1100,7 @@ pub(crate) mod support { if matches!(ty, baml_compiler2_tir::ty::Ty::Never { .. }) { None } else { - Some(ty.to_string()) + Some(ty) } } else { None @@ -1121,50 +1121,25 @@ pub(crate) mod support { None => format!(" throws {declared}"), } } else if !synthetic_display_vars.is_empty() { - // Function has implicit effect vars from callback params. - // Union the effect vars with the body's own concrete throws. - let mut all_throws: Vec = synthetic_display_vars - .iter() - .map(|v| v.to_string()) - .collect(); - // Add body's own throws (post_inference_throws excludes TypeVars, - // so these are the function's own concrete throws). - if let Some(ref body_throws) = post_inference_throws { - // Split the body throws string and add each component. - for component in body_throws.split(" | ") { - let trimmed = component.trim(); - if !trimmed.is_empty() - && !all_throws.contains(&trimmed.to_string()) - { - all_throws.push(trimmed.to_string()); - } - } - } - // Sort for deterministic output: effect vars first, then others. - all_throws.sort_by(|a, b| { - let a_is_effect = a.starts_with("__throws_"); - let b_is_effect = b.starts_with("__throws_"); - match (a_is_effect, b_is_effect) { - (true, false) => std::cmp::Ordering::Greater, - (false, true) => std::cmp::Ordering::Less, - _ => a.cmp(b), - } - }); - let throws_str = all_throws.join(" | "); + let body_facts = post_inference_throws + .as_ref() + .map(flatten_ty_to_facts) + .unwrap_or_default(); + let throws_ty = combine_effect_vars_with_body_throws( + &synthetic_display_vars, + body_facts, + ); match &inferred_throws { Some(inferred) => { - format!(" throws {throws_str} infers {inferred}") + format!(" throws {throws_ty} infers {inferred}") } - None => format!(" throws {throws_str}"), + None => format!(" throws {throws_ty}"), } } else { // No explicit throws, no effect vars. // Use post-inference effective throws if available, // falling back to pre-inference transitive throws. - match post_inference_throws - .as_deref() - .or(inferred_throws.as_deref()) - { + match post_inference_throws.as_ref().or(inferred_throws.as_ref()) { Some(inferred) => format!(" throws {inferred}"), None => " throws never".to_string(), } From bab3306fe7fe885fba38eb318191123cfe550c06 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 11:29:16 -0500 Subject: [PATCH 07/26] Fix function-type throws validation and optional-call propagation --- .../baml_compiler2_ast/src/disambiguate.rs | 10 +- .../crates/baml_compiler2_mir/src/lower.rs | 28 +- .../crates/baml_compiler2_tir/src/builder.rs | 198 +++++---- .../src/effective_throws.rs | 22 +- .../src/throws_semantics.rs | 15 +- .../crates/baml_lsp2_actions/src/utils.rs | 28 +- .../array_map_throws.baml | 4 +- .../catch_absorbs_throws.baml | 10 +- .../function_type_throws/fn_decl_throws.baml | 6 +- .../fn_type_alias_throws.baml | 8 + .../function_type_throws/hof_rethrows.baml | 4 +- .../lambda_throws_violation.baml | 12 + .../optional_call_throws.baml | 15 + ...s__catch_return_type_mismatch__04_tir.snap | 2 +- ...pe_throws__01_lexer__array_map_throws.snap | 4 +- ...hrows__01_lexer__catch_absorbs_throws.snap | 10 +- ...type_throws__01_lexer__fn_decl_throws.snap | 6 +- ...hrows__01_lexer__fn_type_alias_throws.snap | 42 ++ ...n_type_throws__01_lexer__hof_rethrows.snap | 4 +- ...ws__01_lexer__lambda_throws_violation.snap | 90 +++++ ...hrows__01_lexer__optional_call_throws.snap | 94 +++++ ...e_throws__02_parser__array_map_throws.snap | 4 +- ...rows__02_parser__catch_absorbs_throws.snap | 10 +- ...ype_throws__02_parser__fn_decl_throws.snap | 6 +- ...rows__02_parser__fn_type_alias_throws.snap | 43 ++ ..._type_throws__02_parser__hof_rethrows.snap | 4 +- ...s__02_parser__lambda_throws_violation.snap | 98 +++++ ...rows__02_parser__optional_call_throws.snap | 134 +++++++ ...l_tests__function_type_throws__03_hir.snap | 44 +- ...tests__function_type_throws__04_5_mir.snap | 379 +++++++++++++----- ...l_tests__function_type_throws__04_tir.snap | 97 ++++- ..._function_type_throws__05_diagnostics.snap | 108 ++--- ...sts__function_type_throws__06_codegen.snap | 154 +++++-- ...hrows__10_formatter__array_map_throws.snap | 4 +- ...s__10_formatter__catch_absorbs_throws.snap | 2 +- ..._throws__10_formatter__fn_decl_throws.snap | 2 +- ...pe_throws__10_formatter__hof_rethrows.snap | 4 +- ...10_formatter__lambda_throws_violation.snap | 11 +- ...s__10_formatter__optional_call_throws.snap | 5 + baml_language/crates/baml_type/src/lib.rs | 51 ++- 40 files changed, 1413 insertions(+), 359 deletions(-) create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/optional_call_throws.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__optional_call_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__optional_call_throws.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__optional_call_throws.snap diff --git a/baml_language/crates/baml_compiler2_ast/src/disambiguate.rs b/baml_language/crates/baml_compiler2_ast/src/disambiguate.rs index 0a2bb7f8cc..539988648d 100644 --- a/baml_language/crates/baml_compiler2_ast/src/disambiguate.rs +++ b/baml_language/crates/baml_compiler2_ast/src/disambiguate.rs @@ -138,11 +138,19 @@ fn validate_type_expr_tree(expr: &TypeExpr, diagnostics: &mut Vec<(String, text_ validate_type_expr_tree(v, diagnostics); } } - TypeExpr::Function { params, ret, .. } => { + TypeExpr::Function { + params, + ret, + throws, + .. + } => { for p in params { validate_type_expr_tree(&p.ty, diagnostics); } validate_type_expr_tree(ret, diagnostics); + if let Some(throws) = throws { + validate_type_expr_tree(throws, diagnostics); + } } _ => {} } diff --git a/baml_language/crates/baml_compiler2_mir/src/lower.rs b/baml_language/crates/baml_compiler2_mir/src/lower.rs index d230fe8273..be52deb94b 100644 --- a/baml_language/crates/baml_compiler2_mir/src/lower.rs +++ b/baml_language/crates/baml_compiler2_mir/src/lower.rs @@ -170,18 +170,22 @@ pub fn convert_tir2_ty(ty: &Tir2Ty, resolved: &ResolvedAliases) -> Ty { ret, throws, attr, - } => Ty::Function { - params: params - .iter() - .map(|(_, t)| convert_tir2_ty(t, resolved)) - .collect(), - ret: Box::new(convert_tir2_ty(ret, resolved)), - throws: match throws.as_ref() { - Tir2Ty::Never { .. } => None, - t => Some(Box::new(convert_tir2_ty(t, resolved))), - }, - attr: attr.clone(), - }, + } => { + let converted_throws = convert_tir2_ty(throws, resolved); + Ty::Function { + params: params + .iter() + .map(|(_, t)| convert_tir2_ty(t, resolved)) + .collect(), + ret: Box::new(convert_tir2_ty(ret, resolved)), + throws: if matches!(converted_throws, Ty::Void { .. }) { + None + } else { + Some(Box::new(converted_throws)) + }, + attr: attr.clone(), + } + } // Bottom / sentinel types Tir2Ty::Never { attr } => Ty::Void { attr: attr.clone() }, diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 513b9a0a52..af4eb0e392 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -718,38 +718,15 @@ impl<'db> TypeInferenceBuilder<'db> { }); // Infer the lambda body using save/restore approach - let (ret_ty, e_body, _lambda_expressions) = self.infer_lambda_body( + let (ret_ty, effective_throws_ty, _lambda_expressions) = + self.infer_lambda_body(func_def, ¶m_tys, return_annotation.as_ref()); + self.finalize_lambda_function_ty( func_def, - ¶m_tys, - return_annotation.as_ref(), - expr_id, - ); - // Determine throws: explicit annotation takes precedence, - // otherwise infer from the body. - let throws_ty = if let Some(te) = &func_def.throws { - let mut diags = Vec::new(); - let declared = crate::lower_type_expr::lower_type_expr_in_ns( - self.context.db(), - &te.expr, - self.package_items, - &self.ns_context, - &self.generic_params, - &mut diags, - ); - for diag in diags { - self.context.report_at_span(diag, te.span); - } - declared - } else { - e_body - }; - - Ty::Function { - params: param_tys, - ret: Box::new(ret_ty), - throws: Box::new(throws_ty), - attr: TyAttr::default(), - } + param_tys, + ret_ty, + effective_throws_ty, + &all_generic_params, + ) } Expr::Missing => Ty::Unknown { attr: TyAttr::default(), @@ -1109,7 +1086,6 @@ impl<'db> TypeInferenceBuilder<'db> { Ty::Function { params: expected_params, ret: expected_ret, - throws: expected_throws, .. } => { // Checking mode: decompose expected function type. @@ -1124,6 +1100,9 @@ impl<'db> TypeInferenceBuilder<'db> { ); } + let mut all_generic_params = self.generic_params.clone(); + all_generic_params.extend(func_def.generic_params.iter().cloned()); + // Determine param types: annotation takes precedence, else use expected let mut param_tys: Vec<(Option, Ty)> = Vec::new(); for (i, param) in func_def.params.iter().enumerate() { @@ -1142,7 +1121,7 @@ impl<'db> TypeInferenceBuilder<'db> { &te.expr, self.package_items, &self.ns_context, - &self.generic_params, + &all_generic_params, &mut diags, ); for diag in diags { @@ -1177,7 +1156,7 @@ impl<'db> TypeInferenceBuilder<'db> { &te.expr, self.package_items, &self.ns_context, - &self.generic_params, + &all_generic_params, &mut diags, ); for diag in diags { @@ -1189,56 +1168,20 @@ impl<'db> TypeInferenceBuilder<'db> { return_annotation.as_ref().unwrap_or(expected_ret.as_ref()); // Infer/check the lambda body using save/restore approach - let (ret_ty, e_body, _lambda_expressions) = self.infer_lambda_body( + let (ret_ty, effective_throws_ty, _lambda_expressions) = + self.infer_lambda_body(func_def, ¶m_tys, Some(effective_ret)); + let result = self.finalize_lambda_function_ty( func_def, - ¶m_tys, - Some(effective_ret), - expr_id, + param_tys, + ret_ty, + effective_throws_ty, + &all_generic_params, ); - // Determine throws: explicit annotation > expected > inferred from body - let throws_ty = if let Some(te) = &func_def.throws { - let mut diags = Vec::new(); - let declared = crate::lower_type_expr::lower_type_expr_in_ns( - self.context.db(), - &te.expr, - self.package_items, - &self.ns_context, - &self.generic_params, - &mut diags, - ); - for diag in diags { - self.context.report_at_span(diag, te.span); - } - declared - } else { - // No annotation: infer from body, use expected_throws for TypeVar binding - let _ = expected_throws; - e_body - }; - - let result = Ty::Function { - params: param_tys, - ret: Box::new(ret_ty), - throws: Box::new(throws_ty.clone()), - attr: TyAttr::default(), - }; - // Check throws compatibility: if expected is `throws never` but - // lambda body throws, report the stored function diagnostic - if matches!(expected_throws.as_ref(), Ty::Never { .. }) - && !matches!(throws_ty, Ty::Never { .. }) - && throws_semantics::function_shape_matches_ignoring_outer_throws( - &result, - &expected_resolved, - &self.aliases, - ) + if !crate::generics::contains_typevar(expected) + && !self.is_subtype(&result, expected) { - self.context.report_simple( - TirTypeError::StoredFunctionRequiresExplicitThrows { - actual_throws: throws_ty, - }, - expr_id, - ); + self.report_type_mismatch(expected, &result, expr_id); } self.record_expr_type(expr_id, result.clone()); @@ -1865,7 +1808,68 @@ impl<'db> TypeInferenceBuilder<'db> { } let effective = self.collect_effective_throws(body); - let diff = throws_semantics::throws_contract_diff(&declared_ty, &effective, &self.aliases); + self.report_throws_contract_diff_at_span(&declared_ty, &effective, span); + } + + fn lower_lambda_throws_annotation( + &mut self, + throws_annotation: &baml_compiler2_ast::SpannedTypeExpr, + generic_params: &[Name], + ) -> Ty { + let mut diags = Vec::new(); + let declared = crate::lower_type_expr::lower_type_expr_in_ns( + self.context.db(), + &throws_annotation.expr, + self.package_items, + &self.ns_context, + generic_params, + &mut diags, + ); + for diag in diags { + self.context.report_at_span(diag, throws_annotation.span); + } + declared + } + + fn finalize_lambda_function_ty( + &mut self, + func_def: &baml_compiler2_ast::FunctionDef, + params: Vec<(Option, Ty)>, + ret_ty: Ty, + effective_throws_ty: Ty, + generic_params: &[Name], + ) -> Ty { + let throws_ty = if let Some(throws_annotation) = &func_def.throws { + let declared_throws_ty = + self.lower_lambda_throws_annotation(throws_annotation, generic_params); + let effective_throws_facts = + throws_semantics::flatten_ty_to_facts(&effective_throws_ty); + self.report_throws_contract_diff_at_span( + &declared_throws_ty, + &effective_throws_facts, + throws_annotation.span, + ); + declared_throws_ty + } else { + effective_throws_ty + }; + + Ty::Function { + params, + ret: Box::new(ret_ty), + throws: Box::new(throws_ty), + attr: TyAttr::default(), + } + } + + fn report_throws_contract_diff_at_span( + &mut self, + declared_ty: &Ty, + effective_facts: &BTreeSet, + span: TextRange, + ) { + let diff = + throws_semantics::throws_contract_diff(declared_ty, effective_facts, &self.aliases); let mut extra: Vec = diff .uncovered_effective @@ -1883,7 +1887,7 @@ impl<'db> TypeInferenceBuilder<'db> { if !extra.is_empty() { self.context.report_at_span( TirTypeError::ThrowsContractViolation { - declared: declared_ty, + declared: declared_ty.clone(), extra_types: extra, }, span, @@ -2619,6 +2623,30 @@ impl<'db> TypeInferenceBuilder<'db> { for arg in args { self.collect_throw_facts_from_expr(*arg, body, out); } + let type_level_facts = self + .expressions + .get(callee) + .and_then(|ty| throws_semantics::function_throws_facts(ty, &self.aliases)) + .and_then(|facts| if facts.is_empty() { None } else { Some(facts) }); + if let Some(facts) = type_level_facts { + out.extend(facts); + } else if let Some(target) = self.call_target_name(*callee, body) { + let throws = crate::throw_inference::function_throw_sets( + self.context.db(), + self.package_id, + ); + if let Some(transitive) = throws.transitive_for(&target) { + out.extend(transitive.iter().cloned()); + } else { + out.insert(Ty::Unknown { + attr: TyAttr::default(), + }); + } + } else { + out.insert(Ty::Unknown { + attr: TyAttr::default(), + }); + } } Expr::Catch { base, .. } => { self.collect_throw_facts_from_expr(*base, body, out); @@ -2911,6 +2939,11 @@ impl<'db> TypeInferenceBuilder<'db> { } }); + // This helper reconstructs function types during lookup/inference rather than + // at the defining declaration site, so any lowering diagnostics must stay + // attached to the original signature instead of being re-emitted at each use. + drop(diags); + Ty::Function { params, ret: Box::new(ret_ty), @@ -4983,7 +5016,6 @@ impl<'db> TypeInferenceBuilder<'db> { func_def: &baml_compiler2_ast::FunctionDef, param_tys: &[(Option, Ty)], expected_ret: Option<&Ty>, - _lambda_expr_id: ExprId, ) -> (Ty, Ty, FxHashMap) { use baml_compiler2_ast::FunctionBodyDef; diff --git a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs index ef2aaaea2f..3e597fdb24 100644 --- a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs +++ b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs @@ -440,6 +440,19 @@ fn collect_effective_throws_from_expr<'db>( out, ); } + collect_effective_throws_from_call( + db, + package_id, + *callee, + body, + expressions, + aliases, + CallResolutionOptions { + include_typevars, + unknown_on_unresolved_call, + }, + out, + ); } Expr::OptionalChain { expr } => { collect_effective_throws_from_expr( @@ -657,7 +670,7 @@ fn collect_effective_throws_from_call<'db>( .get(&callee_expr_id) .and_then(|ty| function_throws_facts(ty, aliases)); - if let Some(facts) = type_level_facts { + if let Some(facts) = type_level_facts.as_ref() { let filtered: BTreeSet = facts .iter() .filter(|fact| options.include_typevars || !matches!(fact, Ty::TypeVar(_, _))) @@ -680,12 +693,7 @@ fn collect_effective_throws_from_call<'db>( } } - if options.unknown_on_unresolved_call - && expressions - .get(&callee_expr_id) - .and_then(|ty| function_throws_facts(ty, aliases)) - .is_none() - { + if options.unknown_on_unresolved_call && type_level_facts.is_none() { out.insert(Ty::Unknown { attr: TyAttr::default(), }); diff --git a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs index d352ff75fc..5a7151fb38 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs @@ -31,12 +31,25 @@ pub(crate) fn function_throws_facts( ty: &Ty, aliases: &HashMap, ) -> Option> { - match resolve_alias_chain(ty, aliases) { + match resolve_callable_function_ty(ty, aliases)? { Ty::Function { throws, .. } => Some(flatten_ty_to_facts(&throws)), _ => None, } } +fn resolve_callable_function_ty(ty: &Ty, aliases: &HashMap) -> Option { + let mut resolved = resolve_alias_chain(ty, aliases); + loop { + match resolved { + Ty::Function { .. } => return Some(resolved), + Ty::Optional(inner, _) => { + resolved = resolve_alias_chain(inner.as_ref(), aliases); + } + _ => return None, + } + } +} + /// Flatten a compound `Ty` into its leaf throw facts. /// Unions and optionals are decomposed; leaf types are kept as-is. pub fn flatten_ty_to_facts(ty: &Ty) -> BTreeSet { diff --git a/baml_language/crates/baml_lsp2_actions/src/utils.rs b/baml_language/crates/baml_lsp2_actions/src/utils.rs index 8cf334f2bf..16a4c5be88 100644 --- a/baml_language/crates/baml_lsp2_actions/src/utils.rs +++ b/baml_language/crates/baml_lsp2_actions/src/utils.rs @@ -150,7 +150,12 @@ pub fn display_ty(ty: &Ty) -> String { } Ty::Optional(inner, _) => format!("{}?", display_ty(inner)), Ty::Literal(lit, _freshness, _) => lit.to_string(), - Ty::Function { params, ret, .. } => { + Ty::Function { + params, + ret, + throws, + .. + } => { let ps: Vec = params .iter() .map(|(name, ty)| { @@ -159,7 +164,12 @@ pub fn display_ty(ty: &Ty) -> String { .unwrap_or_else(|| display_ty(ty)) }) .collect(); - format!("({}) -> {}", ps.join(", "), display_ty(ret)) + let mut rendered = format!("({}) -> {}", ps.join(", "), display_ty(ret)); + if !matches!(throws.as_ref(), Ty::Never { .. }) { + rendered.push_str(" throws "); + rendered.push_str(&display_ty(throws)); + } + rendered } Ty::TypeVar(name, _) => name.to_string(), Ty::Never { .. } => "never".to_string(), @@ -222,7 +232,12 @@ pub fn display_type_expr(te: &TypeExpr) -> String { parts.join(" | ") } TypeExpr::Literal { value, .. } => value.to_string(), - TypeExpr::Function { params, ret, .. } => { + TypeExpr::Function { + params, + ret, + throws, + .. + } => { let ps: Vec = params .iter() .map(|p| { @@ -232,7 +247,12 @@ pub fn display_type_expr(te: &TypeExpr) -> String { .unwrap_or_else(|| display_type_expr(&p.ty)) }) .collect(); - format!("({}) -> {}", ps.join(", "), display_type_expr(ret)) + let mut rendered = format!("({}) -> {}", ps.join(", "), display_type_expr(ret)); + if let Some(throws) = throws { + rendered.push_str(" throws "); + rendered.push_str(&display_type_expr(throws)); + } + rendered } TypeExpr::BuiltinUnknown { .. } => "unknown".to_string(), TypeExpr::Never { .. } => "never".to_string(), diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/array_map_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/array_map_throws.baml index f364825be1..147cf65910 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/array_map_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/array_map_throws.baml @@ -1,13 +1,13 @@ // === Array.map with throwing callback === // Pure map callback -function test_map_pure() -> int[] { +function test_array_map_pure() -> int[] { let items: int[] = [1, 2, 3] items.map((x) -> { x * 2 }) } // Throwing map callback -function test_map_throwing() -> int[] { +function test_array_map_throwing() -> int[] { let items: int[] = [1, 2, 3] items.map((x) -> { if (x == 2) { throw "found two" } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/catch_absorbs_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/catch_absorbs_throws.baml index bd21c4a5a0..9184131682 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/catch_absorbs_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/catch_absorbs_throws.baml @@ -1,32 +1,32 @@ // === catch should absorb throws === -function may_fail(x: int) -> string throws string { +function caught_may_fail(x: int) -> string throws string { if (x == 0) { throw "zero" } "ok" } // Calling a throwing function without catch - should propagate throws function test_no_catch() -> string { - may_fail(1) + caught_may_fail(1) } // Catching with wildcard - should absorb throws function test_catch_all() -> string { - may_fail(1) catch (e) { + caught_may_fail(1) catch (e) { _ => "caught" } } // Catching specific type - should absorb that type function test_catch_string() -> string { - may_fail(1) catch (e) { + caught_may_fail(1) catch (e) { _: string => "caught string" } } // Catch then re-throw - should propagate the re-thrown type function test_catch_rethrow() -> string { - may_fail(1) catch (e) { + caught_may_fail(1) catch (e) { _ => throw 42 } } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/fn_decl_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/fn_decl_throws.baml index edfe691922..76ededa5ce 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/fn_decl_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/fn_decl_throws.baml @@ -1,7 +1,7 @@ // === Function declaration throws preserved in type === // Function with explicit throws - should show throws string in type -function may_fail(x: int) -> string throws string { +function declared_may_fail(x: int) -> string throws string { if (x == 0) { throw "zero" } "ok" } @@ -13,12 +13,12 @@ function always_ok(x: int) -> string { // Function calling a throwing function - should propagate throws function caller() -> string { - may_fail(1) + declared_may_fail(1) } // Function calling a throwing function with catch - should not propagate function safe_caller() -> string { - may_fail(1) catch (e) { + declared_may_fail(1) catch (e) { _ => "caught" } } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/fn_type_alias_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/fn_type_alias_throws.baml index 29091e4fbc..e657208689 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/fn_type_alias_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/fn_type_alias_throws.baml @@ -9,8 +9,16 @@ type PureCallback = () -> int // Type alias with throws never explicitly type ExplicitPure = () -> int throws never +// Throws alias resolving to never should still lower to a closed function type +type NoThrow = never +type AliasPure = () -> int throws NoThrow + // Parameterized function type with throws type Mapper = (int) -> string throws string // Nested function type - outer throws, inner pure type Wrapper = (() -> int) -> int throws string + +function use_alias_pure(f: AliasPure) -> int { + f() +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/hof_rethrows.baml b/baml_language/crates/baml_tests/projects/function_type_throws/hof_rethrows.baml index 5da9797b74..6c79e9dc16 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/hof_rethrows.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/hof_rethrows.baml @@ -34,10 +34,10 @@ function map_it(x: int, f: (int) -> string) -> string { f(x) } -function test_map_pure() -> string { +function test_map_it_pure() -> string { map_it(1, (n: int) -> string { "ok" }) } -function test_map_throwing() -> string { +function test_map_it_throwing() -> string { map_it(1, (n: int) -> string { throw "bad" }) } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_violation.baml b/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_violation.baml index 85fcfff16e..e22a423bd3 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_violation.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/lambda_throws_violation.baml @@ -5,3 +5,15 @@ function test_throws_never_but_throws() -> int { let f = () -> int throws never { throw "boom" } f() } + +// Lambda checked against an expected callback type should not under-declare throws. +function test_typed_lambda_throws_mismatch() -> int { + let f: () -> int throws string = () -> int { throw 42 } + f() +} + +// Explicit lambda throws annotations must still cover the body's escaping throws. +function test_typed_lambda_annotation_underdeclares_body() -> int { + let f: () -> int throws string = () -> int throws string { throw 42 } + f() +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/optional_call_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/optional_call_throws.baml new file mode 100644 index 0000000000..ffa5a30ed2 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/optional_call_throws.baml @@ -0,0 +1,15 @@ +// === Optional call should propagate callback throws === + +function optional_call_rethrows(cb: (() -> int throws string)?) -> int? { + cb?.() +} + +function optional_call_caught(cb: (() -> int throws string)?) -> int? { + cb?.() catch (e) { + _ => null + } +} + +function test_optional_call_with_throwing_callback() -> int? { + optional_call_rethrows(() -> int { throw "boom" }) +} diff --git a/baml_language/crates/baml_tests/snapshots/catch_return_type_mismatch/baml_tests__catch_return_type_mismatch__04_tir.snap b/baml_language/crates/baml_tests/snapshots/catch_return_type_mismatch/baml_tests__catch_return_type_mismatch__04_tir.snap index 9d87ca9bf5..7fd82db80b 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_return_type_mismatch/baml_tests__catch_return_type_mismatch__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_return_type_mismatch/baml_tests__catch_return_type_mismatch__04_tir.snap @@ -9,7 +9,7 @@ function user.callee() -> bool throws user.Errors { } ?? 218..225: extraneous throws declaration: user.Errors } -function user.caller() -> bool throws never { +function user.caller() -> bool throws user.Errors { { : bool catch (callee() : bool) : unknown catch (e) diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__array_map_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__array_map_throws.snap index c8e8fe73d9..85ccaa222e 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__array_map_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__array_map_throws.snap @@ -19,7 +19,7 @@ Word "Pure" Word "map" Word "callback" Function "function" -Word "test_map_pure" +Word "test_array_map_pure" LParen "(" RParen ")" Arrow "->" @@ -62,7 +62,7 @@ Word "Throwing" Word "map" Word "callback" Function "function" -Word "test_map_throwing" +Word "test_array_map_throwing" LParen "(" RParen ")" Arrow "->" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__catch_absorbs_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__catch_absorbs_throws.snap index 6d68ac7c0f..d015f7774d 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__catch_absorbs_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__catch_absorbs_throws.snap @@ -12,7 +12,7 @@ Throws "throws" EqualsEquals "==" Equals "=" Function "function" -Word "may_fail" +Word "caught_may_fail" LParen "(" Word "x" Colon ":" @@ -58,7 +58,7 @@ RParen ")" Arrow "->" Word "string" LBrace "{" -Word "may_fail" +Word "caught_may_fail" LParen "(" IntegerLiteral "1" RParen ")" @@ -79,7 +79,7 @@ RParen ")" Arrow "->" Word "string" LBrace "{" -Word "may_fail" +Word "caught_may_fail" LParen "(" IntegerLiteral "1" RParen ")" @@ -112,7 +112,7 @@ RParen ")" Arrow "->" Word "string" LBrace "{" -Word "may_fail" +Word "caught_may_fail" LParen "(" IntegerLiteral "1" RParen ")" @@ -149,7 +149,7 @@ RParen ")" Arrow "->" Word "string" LBrace "{" -Word "may_fail" +Word "caught_may_fail" LParen "(" IntegerLiteral "1" RParen ")" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_decl_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_decl_throws.snap index 7a1084d342..8d37b49d64 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_decl_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_decl_throws.snap @@ -27,7 +27,7 @@ Word "string" In "in" Word "type" Function "function" -Word "may_fail" +Word "declared_may_fail" LParen "(" Word "x" Colon ":" @@ -98,7 +98,7 @@ RParen ")" Arrow "->" Word "string" LBrace "{" -Word "may_fail" +Word "declared_may_fail" LParen "(" IntegerLiteral "1" RParen ")" @@ -123,7 +123,7 @@ RParen ")" Arrow "->" Word "string" LBrace "{" -Word "may_fail" +Word "declared_may_fail" LParen "(" IntegerLiteral "1" RParen ")" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_type_alias_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_type_alias_throws.snap index 3ee4e07410..ae9fdbb2aa 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_type_alias_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__fn_type_alias_throws.snap @@ -65,6 +65,34 @@ Throws "throws" Word "never" Slash "/" Slash "/" +Word "Throws" +Word "alias" +Word "resolving" +Word "to" +Word "never" +Word "should" +Word "still" +Word "lower" +Word "to" +Word "a" +Word "closed" +Function "function" +Word "type" +Word "type" +Word "NoThrow" +Equals "=" +Word "never" +Word "type" +Word "AliasPure" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "NoThrow" +Slash "/" +Slash "/" Word "Parameterized" Function "function" Word "type" @@ -104,3 +132,17 @@ Arrow "->" Word "int" Throws "throws" Word "string" +Function "function" +Word "use_alias_pure" +LParen "(" +Word "f" +Colon ":" +Word "AliasPure" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_rethrows.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_rethrows.snap index b0d0ceed85..88900c156f 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_rethrows.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_rethrows.snap @@ -252,7 +252,7 @@ Word "x" RParen ")" RBrace "}" Function "function" -Word "test_map_pure" +Word "test_map_it_pure" LParen "(" RParen ")" Arrow "->" @@ -277,7 +277,7 @@ RBrace "}" RParen ")" RBrace "}" Function "function" -Word "test_map_throwing" +Word "test_map_it_throwing" LParen "(" RParen ")" Arrow "->" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_violation.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_violation.snap index 485cddc227..cc1ff8c3fc 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_violation.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__lambda_throws_violation.snap @@ -55,3 +55,93 @@ Word "f" LParen "(" RParen ")" RBrace "}" +Slash "/" +Slash "/" +Word "Lambda" +Word "checked" +Word "against" +Word "an" +Word "expected" +Word "callback" +Word "type" +Word "should" +Word "not" +Word "under-declare" +Throws "throws" +Dot "." +Function "function" +Word "test_typed_lambda_throws_mismatch" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Explicit" +Word "lambda" +Throws "throws" +Word "annotations" +Word "must" +Word "still" +Word "cover" +Word "the" +Word "body" +Error "'" +Word "s" +Word "escaping" +Throws "throws" +Dot "." +Function "function" +Word "test_typed_lambda_annotation_underdeclares_body" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +Word "f" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__optional_call_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__optional_call_throws.snap new file mode 100644 index 0000000000..81abe516e3 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__optional_call_throws.snap @@ -0,0 +1,94 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Optional" +Word "call" +Word "should" +Word "propagate" +Word "callback" +Throws "throws" +EqualsEquals "==" +Equals "=" +Function "function" +Word "optional_call_rethrows" +LParen "(" +Word "cb" +Colon ":" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +RParen ")" +Question "?" +RParen ")" +Arrow "->" +Word "int" +Question "?" +LBrace "{" +Word "cb" +QuestionDot "?." +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "optional_call_caught" +LParen "(" +Word "cb" +Colon ":" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "string" +RParen ")" +Question "?" +RParen ")" +Arrow "->" +Word "int" +Question "?" +LBrace "{" +Word "cb" +QuestionDot "?." +LParen "(" +RParen ")" +Catch "catch" +LParen "(" +Word "e" +RParen ")" +LBrace "{" +Word "_" +FatArrow "=>" +Word "null" +RBrace "}" +RBrace "}" +Function "function" +Word "test_optional_call_with_throwing_callback" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Question "?" +LBrace "{" +Word "optional_call_rethrows" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "boom" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__array_map_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__array_map_throws.snap index 53b8a0c345..75178d9639 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__array_map_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__array_map_throws.snap @@ -5,7 +5,7 @@ source: crates/baml_tests/src/generated_tests.rs SOURCE_FILE FUNCTION_DEF KW_FUNCTION "function" - WORD "test_map_pure" + WORD "test_array_map_pure" PARAMETER_LIST "()" L_PAREN "(" R_PAREN ")" @@ -59,7 +59,7 @@ SOURCE_FILE R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" - WORD "test_map_throwing" + WORD "test_array_map_throwing" PARAMETER_LIST "()" L_PAREN "(" R_PAREN ")" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__catch_absorbs_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__catch_absorbs_throws.snap index ddc109d949..6fab679f56 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__catch_absorbs_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__catch_absorbs_throws.snap @@ -5,7 +5,7 @@ source: crates/baml_tests/src/generated_tests.rs SOURCE_FILE FUNCTION_DEF KW_FUNCTION "function" - WORD "may_fail" + WORD "caught_may_fail" PARAMETER_LIST L_PAREN "(" PARAMETER @@ -61,7 +61,7 @@ SOURCE_FILE BLOCK_EXPR L_BRACE "{" CALL_EXPR - WORD "may_fail" + WORD "caught_may_fail" CALL_ARGS "(1)" L_PAREN "(" INTEGER_LITERAL "1" @@ -81,7 +81,7 @@ SOURCE_FILE L_BRACE "{" CATCH_EXPR CALL_EXPR - WORD "may_fail" + WORD "caught_may_fail" CALL_ARGS "(1)" L_PAREN "(" INTEGER_LITERAL "1" @@ -117,7 +117,7 @@ SOURCE_FILE L_BRACE "{" CATCH_EXPR CALL_EXPR - WORD "may_fail" + WORD "caught_may_fail" CALL_ARGS "(1)" L_PAREN "(" INTEGER_LITERAL "1" @@ -157,7 +157,7 @@ SOURCE_FILE L_BRACE "{" CATCH_EXPR CALL_EXPR - WORD "may_fail" + WORD "caught_may_fail" CALL_ARGS "(1)" L_PAREN "(" INTEGER_LITERAL "1" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_decl_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_decl_throws.snap index 517de1bd1f..5b175096a3 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_decl_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_decl_throws.snap @@ -5,7 +5,7 @@ source: crates/baml_tests/src/generated_tests.rs SOURCE_FILE FUNCTION_DEF KW_FUNCTION "function" - WORD "may_fail" + WORD "declared_may_fail" PARAMETER_LIST L_PAREN "(" PARAMETER @@ -84,7 +84,7 @@ SOURCE_FILE BLOCK_EXPR L_BRACE "{" CALL_EXPR - WORD "may_fail" + WORD "declared_may_fail" CALL_ARGS "(1)" L_PAREN "(" INTEGER_LITERAL "1" @@ -104,7 +104,7 @@ SOURCE_FILE L_BRACE "{" CATCH_EXPR CALL_EXPR - WORD "may_fail" + WORD "declared_may_fail" CALL_ARGS "(1)" L_PAREN "(" INTEGER_LITERAL "1" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_type_alias_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_type_alias_throws.snap index 09c1501ea9..76885d6465 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_type_alias_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__fn_type_alias_throws.snap @@ -41,6 +41,26 @@ SOURCE_FILE KW_THROWS "throws" TYPE_EXPR "never" WORD "never" + TYPE_ALIAS_DEF + WORD "type" + WORD "NoThrow" + EQUALS "=" + TYPE_EXPR "never" + WORD "never" + TYPE_ALIAS_DEF + WORD "type" + WORD "AliasPure" + EQUALS "=" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "NoThrow" + WORD "NoThrow" TYPE_ALIAS_DEF WORD "type" WORD "Mapper" @@ -79,6 +99,29 @@ SOURCE_FILE KW_THROWS "throws" TYPE_EXPR "string" WORD "string" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "use_alias_pure" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR "AliasPure" + WORD "AliasPure" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" === ERRORS === None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_rethrows.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_rethrows.snap index 6f5e45b250..9cea3d30ff 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_rethrows.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_rethrows.snap @@ -342,7 +342,7 @@ SOURCE_FILE R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" - WORD "test_map_pure" + WORD "test_map_it_pure" PARAMETER_LIST "()" L_PAREN "(" R_PAREN ")" @@ -381,7 +381,7 @@ SOURCE_FILE R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" - WORD "test_map_throwing" + WORD "test_map_it_throwing" PARAMETER_LIST "()" L_PAREN "(" R_PAREN ")" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_violation.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_violation.snap index e677bdcdfa..8dff63cb37 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_violation.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__lambda_throws_violation.snap @@ -46,6 +46,104 @@ SOURCE_FILE L_PAREN "(" R_PAREN ")" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_typed_lambda_throws_mismatch" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_typed_lambda_annotation_underdeclares_body" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" === ERRORS === None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__optional_call_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__optional_call_throws.snap new file mode 100644 index 0000000000..216418601f --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__optional_call_throws.snap @@ -0,0 +1,134 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "optional_call_rethrows" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "cb" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + QUESTION "?" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int?" + WORD "int" + QUESTION "?" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OPTIONAL_CALL_EXPR + WORD "cb" + QUESTION_DOT "?." + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "optional_call_caught" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "cb" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "string" + WORD "string" + R_PAREN ")" + QUESTION "?" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int?" + WORD "int" + QUESTION "?" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CATCH_EXPR + OPTIONAL_CALL_EXPR + WORD "cb" + QUESTION_DOT "?." + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + CATCH_CLAUSE + KW_CATCH "catch" + L_PAREN "(" + CATCH_PATTERN "e" + WORD "e" + R_PAREN ")" + L_BRACE "{" + CATCH_ARM + CATCH_PATTERN "_" + WORD "_" + FAT_ARROW "=>" + WORD "null" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_optional_call_with_throwing_callback" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int?" + WORD "int" + QUESTION "?" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "optional_call_rethrows" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "boom" + QUOTE """ + WORD "boom" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index e649b3c3c7..431c8e0b62 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -2,30 +2,30 @@ source: crates/baml_tests/src/generated_tests.rs --- === HIR2 === -function user.test_map_pure() -> int[] [expr] { +function user.test_array_map_pure() -> int[] [expr] { { let items: int[] = [1, 2, 3] } items.map((x) -> { { } x Mul 2 }) } -function user.test_map_throwing() -> int[] [expr] { +function user.test_array_map_throwing() -> int[] [expr] { { let items: int[] = [1, 2, 3] } items.map((x) -> { { if (x Eq 2) { throw "found two" } } x Mul 2 }) } type user.ThrowingFn = () -> int function user.takes_throwing(f: () -> int) -> int [expr] { { } f() } -function user.may_fail(x: int) -> string [expr] { +function user.caught_may_fail(x: int) -> string [expr] { { if (x Eq 0) { throw "zero" } } "ok" } function user.test_catch_all() -> string [expr] { - { } may_fail(1) catch (e) { _ => "caught" } + { } caught_may_fail(1) catch (e) { _ => "caught" } } function user.test_catch_rethrow() -> string [expr] { - { } may_fail(1) catch (e) { _ => throw 42 } + { } caught_may_fail(1) catch (e) { _ => throw 42 } } function user.test_catch_string() -> string [expr] { - { } may_fail(1) catch (e) { _: string => "caught string" } + { } caught_may_fail(1) catch (e) { _: string => "caught string" } } function user.test_no_catch() -> string [expr] { - { } may_fail(1) + { } caught_may_fail(1) } function user.apply_inner(f: () -> int) -> int [expr] { { } f() @@ -194,19 +194,24 @@ function user.always_ok(x: int) -> string [expr] { { } "always ok" } function user.caller() -> string [expr] { - { } may_fail(1) + { } declared_may_fail(1) } -function user.may_fail(x: int) -> string [expr] { +function user.declared_may_fail(x: int) -> string [expr] { { if (x Eq 0) { throw "zero" } } "ok" } function user.safe_caller() -> string [expr] { - { } may_fail(1) catch (e) { _ => "caught" } + { } declared_may_fail(1) catch (e) { _ => "caught" } } +type user.AliasPure = () -> int type user.ExplicitPure = () -> int type user.Mapper = (int) -> string +type user.NoThrow = never type user.PureCallback = () -> int type user.ThrowingCallback = () -> int type user.Wrapper = (() -> int) -> int +function user.use_alias_pure(f: user.AliasPure) -> int [expr] { + { } f() +} function user.apply_guarded(f: () -> int) -> int [expr] { { let result = f(); if (result Lt 0) { throw "negative result" } } result } @@ -228,10 +233,10 @@ function user.run_throwing(f: () -> int) -> int [expr] { function user.run_two(f: () -> int, g: () -> int) -> int [expr] { { } f() Add g() } -function user.test_map_pure() -> string [expr] { +function user.test_map_it_pure() -> string [expr] { { } map_it(1, (n: int) -> string { { } "ok" }) } -function user.test_map_throwing() -> string [expr] { +function user.test_map_it_throwing() -> string [expr] { { } map_it(1, (n: int) -> string { { throw "bad" } }) } function user.test_run_pure() -> int [expr] { @@ -309,6 +314,12 @@ function user.test_throwing_lambda() -> int [expr] { function user.test_throws_never_but_throws() -> int [expr] { { let f = () -> int throws never { { throw "boom" } } } f() } +function user.test_typed_lambda_annotation_underdeclares_body() -> int [expr] { + { let f: () -> int = () -> int throws string { { throw 42 } } } f() +} +function user.test_typed_lambda_throws_mismatch() -> int [expr] { + { let f: () -> int = () -> int { { throw 42 } } } f() +} function user.apply_with_arg(x: int, f: (int) -> int) -> int [expr] { { } f(x) } @@ -333,6 +344,15 @@ function user.test_nested_inner_throws() -> int [expr] { function user.test_nested_outer_throws() -> int [expr] { { let outer = () -> int { { let inner = () -> int { { } 42 }; throw "outer boom" } } } outer() } +function user.optional_call_caught(cb: () -> int?) -> int? [expr] { + { } cb?.() catch (e) { _ => null } +} +function user.optional_call_rethrows(cb: () -> int?) -> int? [expr] { + { } cb?.() +} +function user.test_optional_call_with_throwing_callback() -> int? [expr] { + { } optional_call_rethrows(() -> int { { throw "boom" } }) +} function user.make_pure() -> () -> int [expr] { { return () -> int { { } 42 } } } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index 6b97a629ec..6754580aca 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -2,7 +2,7 @@ source: crates/baml_tests/src/generated_tests.rs --- === MIR2 === -fn user.test_map_pure() -> int[] { +fn user.test_array_map_pure() -> int[] { // Locals: let _0: int[] // _0 // return let _1: int[] // items @@ -26,7 +26,7 @@ fn user.test_map_pure() -> int[] { } // lambda[0] -fn .(x: null) -> null { +fn .(x: null) -> null { // Locals: let _0: null // _0 // return let _1: null // x // param @@ -41,12 +41,12 @@ fn .(x: null) -> null { } } -fn user.test_map_throwing() -> int[] { +fn user.test_array_map_throwing() -> int[] { // Locals: let _0: int[] // _0 // return let _1: int[] // items let _2: int[] - let _3: (int) -> int + let _3: (int) -> int throws string bb0: { _1 = [const 1_i64, const 2_i64, const 3_i64]; @@ -65,7 +65,7 @@ fn user.test_map_throwing() -> int[] { } // lambda[0] -fn .(x: null) -> null { +fn .(x: null) -> null { // Locals: let _0: null // _0 // return let _1: null // x // param @@ -94,10 +94,10 @@ fn .(x: null) -> null { } } -fn user.takes_throwing(f: () -> int) -> int { +fn user.takes_throwing(f: () -> int throws string) -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> int throws string // f // param bb0: { _0 = call copy _1() -> [bb1]; @@ -112,7 +112,7 @@ fn user.takes_throwing(f: () -> int) -> int { } } -fn user.may_fail(x: int) -> string { +fn user.caught_may_fail(x: int) -> string { // Locals: let _0: string // _0 // return let _1: int // x // param @@ -147,7 +147,7 @@ fn user.test_catch_all() -> string { let _1: unknown // e bb0: { - _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb2]; + _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb1, unwind: bb2]; } bb1: { @@ -182,7 +182,7 @@ fn user.test_catch_rethrow() -> string { let _1: unknown // e bb0: { - _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb4]; + _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb1, unwind: bb4]; } bb1: { @@ -217,7 +217,7 @@ fn user.test_catch_string() -> string { let _2: bool bb0: { - _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb2]; + _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb1, unwind: bb2]; } bb1: { @@ -252,7 +252,7 @@ fn user.test_no_catch() -> string { let _0: string // _0 // return bb0: { - _0 = call const fn user.may_fail(const 1_i64) -> [bb1]; + _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb1]; } bb1: { @@ -337,7 +337,7 @@ fn .() -> null { fn user.test_chained_throwing() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -397,7 +397,7 @@ fn .() -> null { fn user.make_throwing_handler() -> ThrowingHandler { // Locals: let _0: ThrowingHandler // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -467,8 +467,8 @@ fn .(a: void) -> null { fn user.test_compose_both_throw() -> (int) -> string { // Locals: let _0: (int) -> string // _0 // return - let _1: (int) -> void - let _2: (int) -> void + let _1: (int) -> void throws string + let _2: (int) -> void throws int bb0: { _1 = make_closure lambda[0](); @@ -510,7 +510,7 @@ fn .(y: int) -> null { fn user.test_compose_first_throws() -> (int) -> string { // Locals: let _0: (int) -> string // _0 // return - let _1: (int) -> void + let _1: (int) -> void throws string let _2: (int) -> "result" bb0: { @@ -612,7 +612,7 @@ fn user.test_compose_second_throws() -> (int) -> string { // Locals: let _0: (int) -> string // _0 // return let _1: (int) -> int - let _2: (int) -> void + let _2: (int) -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -692,10 +692,10 @@ fn user.apply_generic_enum(f: () -> int) -> int { } } -fn user.apply_may_throw_class(f: () -> int) -> int { +fn user.apply_may_throw_class(f: () -> int throws ApiError) -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> int throws ApiError // f // param bb0: { _0 = call copy _1() -> [bb1]; @@ -710,10 +710,10 @@ fn user.apply_may_throw_class(f: () -> int) -> int { } } -fn user.apply_may_throw_enum(f: () -> int) -> int { +fn user.apply_may_throw_enum(f: () -> int throws ErrorKind) -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> int throws ErrorKind // f // param bb0: { _0 = call copy _1() -> [bb1]; @@ -889,7 +889,7 @@ fn user.catch_mixed_errors(x: int) -> string { fn user.make_class_handler() -> ClassThrowingHandler { // Locals: let _0: ClassThrowingHandler // _0 // return - let _1: () -> void + let _1: () -> void throws ApiError bb0: { _1 = make_closure lambda[0](); @@ -917,7 +917,7 @@ fn .() -> null { fn user.make_enum_handler() -> EnumThrowingHandler { // Locals: let _0: EnumThrowingHandler // _0 // return - let _1: () -> void + let _1: () -> void throws ErrorKind.NotFound bb0: { _1 = make_closure lambda[0](); @@ -943,7 +943,7 @@ fn .() -> null { fn user.test_apply_class_thrower() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws ApiError bb0: { _1 = make_closure lambda[0](); @@ -974,7 +974,7 @@ fn .() -> null { fn user.test_apply_enum_thrower() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws ErrorKind.Unauthorized bb0: { _1 = make_closure lambda[0](); @@ -1003,8 +1003,8 @@ fn .() -> null { fn user.test_lambda_throws_class() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void // f - let _2: () -> void + let _1: () -> void throws ValidationError // f + let _2: () -> void throws ValidationError bb0: { _1 = make_closure lambda[0](); @@ -1036,8 +1036,8 @@ fn .() -> null { fn user.test_lambda_throws_enum() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void // f - let _2: () -> void + let _1: () -> void throws ErrorKind.ValidationFailed // f + let _2: () -> void throws ErrorKind.ValidationFailed bb0: { _1 = make_closure lambda[0](); @@ -1067,7 +1067,7 @@ fn .() -> null { fn user.test_rethrows_class() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws ValidationError bb0: { _1 = make_closure lambda[0](); @@ -1098,7 +1098,7 @@ fn .() -> null { fn user.test_rethrows_enum() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws ErrorKind.NotFound bb0: { _1 = make_closure lambda[0](); @@ -1127,7 +1127,7 @@ fn .() -> null { fn user.test_type_alias_class() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws ApiError bb0: { _1 = make_closure lambda[0](); @@ -1158,7 +1158,7 @@ fn .() -> null { fn user.test_type_alias_enum() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws ErrorKind.ValidationFailed bb0: { _1 = make_closure lambda[0](); @@ -1187,7 +1187,7 @@ fn .() -> null { fn user.test_type_alias_mixed() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -1451,10 +1451,10 @@ fn user.throw_various_errors(x: int) -> string { } } -fn user.use_class_thrower(f: () -> int) -> int { +fn user.use_class_thrower(f: () -> int throws ApiError) -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> int throws ApiError // f // param bb0: { _0 = call copy _1() -> [bb1]; @@ -1469,10 +1469,10 @@ fn user.use_class_thrower(f: () -> int) -> int { } } -fn user.use_enum_thrower(f: () -> int) -> int { +fn user.use_enum_thrower(f: () -> int throws ErrorKind) -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> int throws ErrorKind // f // param bb0: { _0 = call copy _1() -> [bb1]; @@ -1487,10 +1487,10 @@ fn user.use_enum_thrower(f: () -> int) -> int { } } -fn user.use_mixed_thrower(f: () -> int) -> int { +fn user.use_mixed_thrower(f: () -> int throws ErrorKind | ApiError | string) -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> int throws ErrorKind | ApiError | string // f // param bb0: { _0 = call copy _1() -> [bb1]; @@ -1505,10 +1505,10 @@ fn user.use_mixed_thrower(f: () -> int) -> int { } } -fn user.apply_explicit(f: () -> int) -> int { +fn user.apply_explicit(f: () -> int throws string) -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> int throws string // f // param bb0: { _0 = call copy _1() -> [bb1]; @@ -1578,7 +1578,7 @@ fn .() -> null { fn user.test_explicit_param_throwing() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -1658,7 +1658,7 @@ fn user.caller() -> string { let _0: string // _0 // return bb0: { - _0 = call const fn user.may_fail(const 1_i64) -> [bb1]; + _0 = call const fn user.declared_may_fail(const 1_i64) -> [bb1]; } bb1: { @@ -1670,7 +1670,7 @@ fn user.caller() -> string { } } -fn user.may_fail(x: int) -> string { +fn user.declared_may_fail(x: int) -> string { // Locals: let _0: string // _0 // return let _1: int // x // param @@ -1705,7 +1705,7 @@ fn user.safe_caller() -> string { let _1: unknown // e bb0: { - _0 = call const fn user.may_fail(const 1_i64) -> [bb1, unwind: bb2]; + _0 = call const fn user.declared_may_fail(const 1_i64) -> [bb1, unwind: bb2]; } bb1: { @@ -1734,6 +1734,24 @@ fn user.safe_caller() -> string { } } +fn user.use_alias_pure(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + fn user.apply_guarded(f: () -> int) -> int { // Locals: let _0: int // _0 // return @@ -1807,7 +1825,7 @@ fn .() -> null { fn user.test_guarded_throwing() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws int bb0: { _1 = make_closure lambda[0](); @@ -1914,7 +1932,7 @@ fn user.run_two(f: () -> int, g: () -> int) -> int { } } -fn user.test_map_pure() -> string { +fn user.test_map_it_pure() -> string { // Locals: let _0: string // _0 // return let _1: (int) -> "ok" @@ -1934,7 +1952,7 @@ fn user.test_map_pure() -> string { } // lambda[0] -fn .(n: int) -> null { +fn .(n: int) -> null { // Locals: let _0: null // _0 // return let _1: int // n // param @@ -1949,10 +1967,10 @@ fn .(n: int) -> null { } } -fn user.test_map_throwing() -> string { +fn user.test_map_it_throwing() -> string { // Locals: let _0: string // _0 // return - let _1: (int) -> void + let _1: (int) -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -1969,7 +1987,7 @@ fn user.test_map_throwing() -> string { } // lambda[0] -fn .(n: int) -> null { +fn .(n: int) -> null { // Locals: let _0: null // _0 // return let _1: int // n // param @@ -2016,7 +2034,7 @@ fn .() -> null { fn user.test_run_throwing() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -2045,8 +2063,8 @@ fn .() -> null { fn user.test_two_both_throw() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void - let _2: () -> void + let _1: () -> void throws string + let _2: () -> void throws int bb0: { _1 = make_closure lambda[0](); @@ -2087,7 +2105,7 @@ fn user.test_two_one_throws() -> int { // Locals: let _0: int // _0 // return let _1: () -> 1 - let _2: () -> void + let _2: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -2234,10 +2252,10 @@ fn user.apply_and_throw(f: () -> int) -> int { } } -fn user.apply_throwing(f: () -> int) -> int { +fn user.apply_throwing(f: () -> int throws string) -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f // param + let _1: () -> int throws string // f // param bb0: { _0 = call copy _1() -> [bb1]; @@ -2321,7 +2339,7 @@ fn .() -> null { fn user.test_apply_explicit_throws() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -2384,7 +2402,7 @@ fn .() -> null { fn user.test_apply_throwing() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -2447,7 +2465,7 @@ fn .() -> null { fn user.test_apply_with_helper_throwing() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void + let _1: () -> void throws int bb0: { _1 = make_closure lambda[0](); @@ -2476,8 +2494,8 @@ fn .() -> null { fn user.test_explicit_throws_match() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void // f - let _2: () -> void + let _1: () -> void throws string // f + let _2: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -2543,8 +2561,8 @@ fn .() -> null { fn user.test_explicit_throws_wider_than_body() -> int { // Locals: let _0: int // _0 // return - let _1: () -> 42 // f - let _2: () -> 42 + let _1: () -> 42 throws string // f + let _2: () -> 42 throws string bb0: { _1 = make_closure lambda[0](); @@ -2580,8 +2598,8 @@ fn user.test_conditional_throw(x: int) -> int { // Locals: let _0: int // _0 // return let _1: int // x // param - let _2: (int) -> int // f - let _3: (int) -> int + let _2: (int) -> int throws string // f + let _3: (int) -> int throws string bb0: { _2 = make_closure lambda[0](); @@ -2632,8 +2650,8 @@ fn user.test_multi_throw_types(x: int) -> string { // Locals: let _0: string // _0 // return let _1: int // x // param - let _2: (int) -> "ok" // f - let _3: (int) -> "ok" + let _2: (int) -> "ok" throws int | string // f + let _3: (int) -> "ok" throws int | string bb0: { _2 = make_closure lambda[0](); @@ -2721,8 +2739,8 @@ fn .() -> null { fn user.test_throwing_int() -> string { // Locals: let _0: string // _0 // return - let _1: () -> void // f - let _2: () -> void + let _1: () -> void throws int // f + let _2: () -> void throws int bb0: { _1 = make_closure lambda[0](); @@ -2752,8 +2770,8 @@ fn .() -> null { fn user.test_throwing_lambda() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void // f - let _2: () -> void + let _1: () -> void throws string // f + let _2: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -2811,6 +2829,68 @@ fn .() -> null { } } +fn user.test_typed_lambda_annotation_underdeclares_body() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void throws string // f + let _2: () -> void throws string + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + +fn user.test_typed_lambda_throws_mismatch() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void throws int // f + let _2: () -> void throws int + + bb0: { + _1 = make_closure lambda[0](); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + fn user.apply_with_arg(x: int, f: (int) -> int) -> int { // Locals: let _0: int // _0 // return @@ -2924,7 +3004,7 @@ fn .(n: int) -> null { fn user.test_mixed_throwing() -> int { // Locals: let _0: int // _0 // return - let _1: (int) -> int + let _1: (int) -> int throws string bb0: { _1 = make_closure lambda[0](); @@ -2973,8 +3053,8 @@ fn .(n: int) -> null { fn user.test_nested_both_throw() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void // outer - let _2: () -> void + let _1: () -> void throws int | string // outer + let _2: () -> void throws int | string bb0: { _1 = make_closure lambda[0](); @@ -2995,9 +3075,9 @@ fn user.test_nested_both_throw() -> int { fn .() -> null { // Locals: let _0: null // _0 // return - let _1: () -> void // inner + let _1: () -> void throws int // inner let _2: void - let _3: () -> void + let _3: () -> void throws int bb0: { _1 = make_closure lambda[0](); @@ -3023,8 +3103,8 @@ fn ., 1)>() -> null { fn user.test_nested_inner_throws() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void // outer - let _2: () -> void + let _1: () -> void throws string // outer + let _2: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -3045,8 +3125,8 @@ fn user.test_nested_inner_throws() -> int { fn .() -> null { // Locals: let _0: null // _0 // return - let _1: () -> void // inner - let _2: () -> void + let _1: () -> void throws string // inner + let _2: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -3076,8 +3156,8 @@ fn ., 1)>() -> null { fn user.test_nested_outer_throws() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void // outer - let _2: () -> void + let _1: () -> void throws string // outer + let _2: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -3119,6 +3199,119 @@ fn ., 1)>() -> null { } } +fn user.optional_call_caught(cb: () -> int throws string?) -> int? { + // Locals: + let _0: int? // _0 // return + let _1: () -> int throws string? // cb // param + let _2: unknown // e + let _3: bool + + bb0: { + _3 = copy _1 == const null; + branch copy _3 -> [bb3, bb1]; + } + + bb1: { + _0 = call copy _1() -> [bb2, unwind: bb5]; + } + + bb2: { + goto -> bb4; + } + + bb3: { + _0 = const null; + goto -> bb4; + } + + bb4: { + goto -> bb8; + } + + bb5: { + throw_if_panic copy _2 -> bb6; + } + + bb6: { + goto -> bb7; + } + + bb7: { + _0 = const null; + goto -> bb8; + } + + bb8: { + goto -> bb9; + } + + bb9: { + return; + } +} + +fn user.optional_call_rethrows(cb: () -> int throws string?) -> int? { + // Locals: + let _0: int? // _0 // return + let _1: () -> int throws string? // cb // param + let _2: bool + + bb0: { + _2 = copy _1 == const null; + branch copy _2 -> [bb3, bb1]; + } + + bb1: { + _0 = call copy _1() -> [bb2]; + } + + bb2: { + goto -> bb4; + } + + bb3: { + _0 = const null; + goto -> bb4; + } + + bb4: { + goto -> bb5; + } + + bb5: { + return; + } +} + +fn user.test_optional_call_with_throwing_callback() -> int? { + // Locals: + let _0: int? // _0 // return + let _1: () -> void throws string + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.optional_call_rethrows(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "boom"; + } +} + fn user.make_pure() -> () -> int { // Locals: let _0: () -> int // _0 // return @@ -3148,9 +3341,9 @@ fn .() -> null { } } -fn user.make_thrower() -> () -> int { +fn user.make_thrower() -> () -> int throws string { // Locals: - let _0: () -> int // _0 // return + let _0: () -> int throws string // _0 // return bb0: { _0 = make_closure lambda[0](); @@ -3199,8 +3392,8 @@ fn user.test_use_pure() -> int { fn user.test_use_thrower() -> int { // Locals: let _0: int // _0 // return - let _1: () -> int // f - let _2: () -> int + let _1: () -> int throws string // f + let _2: () -> int throws string bb0: { _1 = call const fn user.make_thrower() -> [bb1]; @@ -3223,7 +3416,7 @@ fn user.test_use_thrower() -> int { fn user.make_bad_stored_handler() -> StoredPureHandler { // Locals: let _0: StoredPureHandler // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); @@ -3330,9 +3523,9 @@ fn .() -> null { } } -fn user.make_stored_closure_with_throws() -> () -> int { +fn user.make_stored_closure_with_throws() -> () -> int throws string { // Locals: - let _0: () -> int // _0 // return + let _0: () -> int throws string // _0 // return bb0: { _0 = make_closure lambda[0](); @@ -3357,7 +3550,7 @@ fn .() -> null { fn user.make_stored_throwing_handler() -> StoredThrowingHandler { // Locals: let _0: StoredThrowingHandler // _0 // return - let _1: () -> void + let _1: () -> void throws string bb0: { _1 = make_closure lambda[0](); diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 5b63867b30..1623f8b300 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -2,7 +2,7 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === -function user.test_map_pure() -> int[] throws never { +function user.test_array_map_pure() -> int[] throws never { { : int[] let items = [1, 2, 3] : int[] items.map((x) -> { ... }) : int[] @@ -12,9 +12,9 @@ function user.test_map_pure() -> int[] throws never { } } } -lambda user.test_map_pure { +lambda user.test_array_map_pure { } -function user.test_map_throwing() -> int[] throws string { +function user.test_array_map_throwing() -> int[] throws string { { : int[] let items = [1, 2, 3] : int[] items.map((x) -> { ... }) : int[] @@ -28,7 +28,7 @@ function user.test_map_throwing() -> int[] throws string { } } } -lambda user.test_map_throwing { +lambda user.test_array_map_throwing { } type user.ThrowingFn = () -> int throws string function user.takes_throwing(f: () -> int throws string) -> int throws string { @@ -37,7 +37,7 @@ function user.takes_throwing(f: () -> int throws string) -> int throws string { } } type user.ThrowingFn$stream = unknown -function user.may_fail(x: int) -> string throws string { +function user.caught_may_fail(x: int) -> string throws string { { : "ok" if (x == 0 : bool) : void { : never @@ -48,12 +48,12 @@ function user.may_fail(x: int) -> string throws string { } function user.test_no_catch() -> string throws string { { : string - may_fail(1) : string + caught_may_fail(1) : string } } function user.test_catch_all() -> string throws never { { : string | "caught" - catch (may_fail(1) : string) : unknown + catch (caught_may_fail(1) : string) : unknown catch (e) _ => "caught" : "caught" @@ -61,7 +61,7 @@ function user.test_catch_all() -> string throws never { } function user.test_catch_string() -> string throws never { { : string | "caught string" - catch (may_fail(1) : string) : unknown + catch (caught_may_fail(1) : string) : unknown catch (e) _: string => "caught string" : "caught string" @@ -69,7 +69,7 @@ function user.test_catch_string() -> string throws never { } function user.test_catch_rethrow() -> string throws int { { : string - catch (may_fail(1) : string) : unknown + catch (caught_may_fail(1) : string) : unknown catch (e) _ => throw 42 : never @@ -560,7 +560,7 @@ function user.test_pure_only_pure() -> int throws never { } lambda user.test_pure_only_pure { } -function user.may_fail(x: int) -> string throws string { +function user.declared_may_fail(x: int) -> string throws string { { : "ok" if (x == 0 : bool) : void { : never @@ -576,12 +576,12 @@ function user.always_ok(x: int) -> string throws never { } function user.caller() -> string throws string { { : string - may_fail(1) : string + declared_may_fail(1) : string } } function user.safe_caller() -> string throws never { { : string | "caught" - catch (may_fail(1) : string) : unknown + catch (declared_may_fail(1) : string) : unknown catch (e) _ => "caught" : "caught" @@ -590,11 +590,20 @@ function user.safe_caller() -> string throws never { type user.ThrowingCallback = () -> int throws string type user.PureCallback = () -> int type user.ExplicitPure = () -> int +type user.NoThrow = never +type user.AliasPure = () -> int throws user.NoThrow type user.Mapper = (int) -> string throws string type user.Wrapper = (() -> int) -> int throws string +function user.use_alias_pure(f: user.AliasPure) -> int throws user.NoThrow { + { : int + f() : int + } +} type user.ThrowingCallback$stream = unknown type user.PureCallback$stream = unknown type user.ExplicitPure$stream = unknown +type user.NoThrow$stream = never +type user.AliasPure$stream = unknown type user.Mapper$stream = unknown type user.Wrapper$stream = unknown function user.apply_guarded(f: () -> int throws __throws_f) -> int throws __throws_f | string { @@ -722,7 +731,7 @@ function user.map_it(x: int, f: (int) -> string throws __throws_f) -> string thr f(x) : string } } -function user.test_map_pure() -> string throws never { +function user.test_map_it_pure() -> string throws never { { : string map_it(1, (n: int) -> string { ... }) : string (n: int) -> string { ... } : (n: int) -> "ok" @@ -731,9 +740,9 @@ function user.test_map_pure() -> string throws never { } } } -lambda user.test_map_pure { +lambda user.test_map_it_pure { } -function user.test_map_throwing() -> string throws string { +function user.test_map_it_throwing() -> string throws string { { : string map_it(1, (n: int) -> string { ... }) : string (n: int) -> string { ... } : (n: int) -> never throws string @@ -742,7 +751,7 @@ function user.test_map_throwing() -> string throws string { } } } -lambda user.test_map_throwing { +lambda user.test_map_it_throwing { } function user.apply(f: () -> int throws __throws_f) -> int throws __throws_f { { : int @@ -874,6 +883,7 @@ function user.test_explicit_throws_wider_than_body() -> int throws string { } f() : 42 } + ?? 547..554: extraneous throws declaration: string } lambda user.test_explicit_throws_wider_than_body { ?? 547..554: extraneous throws declaration: string @@ -951,10 +961,40 @@ function user.test_throws_never_but_throws() -> int throws never { } f() : never } + !! 213..219: throws contract violation: `never` is missing string } lambda user.test_throws_never_but_throws { !! 213..219: throws contract violation: `never` is missing string } +function user.test_typed_lambda_throws_mismatch() -> int throws int { + { : never + let f = : () -> never throws int + () -> int { ... } : () -> never throws int + { + throw 42 + } + f() : never + } + !! 419..442: type mismatch: expected () -> int throws string, got () -> never throws int +} +lambda user.test_typed_lambda_throws_mismatch { +} +function user.test_typed_lambda_annotation_underdeclares_body() -> int throws string { + { : never + let f = : () -> never throws string + () -> int throws string { ... } : () -> never throws string + { + throw 42 + } + f() : never + } + !! 654..661: throws contract violation: `string` is missing int + ?? 654..661: extraneous throws declaration: string +} +lambda user.test_typed_lambda_annotation_underdeclares_body { + !! 654..661: throws contract violation: `string` is missing int + ?? 654..661: extraneous throws declaration: string +} function user.apply_with_arg(x: int, f: (int) -> int throws __throws_f) -> int throws __throws_f { { : int f(x) : int @@ -1062,6 +1102,30 @@ lambda user.test_nested_both_throw { } lambda user.test_nested_both_throw { } +function user.optional_call_rethrows(cb: (() -> int throws string)?) -> int? throws string { + { : int? + cb?.() : int? + } +} +function user.optional_call_caught(cb: (() -> int throws string)?) -> int? throws never { + { : int? | null + catch (cb?.() : int?) : unknown + catch (e) + _ => + null : null + } +} +function user.test_optional_call_with_throwing_callback() -> int? throws string { + { : int? + optional_call_rethrows(() -> int { ... }) : int? + () -> int { ... } : () -> never throws string + { + throw "boom" + } + } +} +lambda user.test_optional_call_with_throwing_callback { +} function user.make_pure() -> () -> int throws never { { : never return : () -> 42 @@ -1147,6 +1211,7 @@ function user.test_stored_local_shape_mismatch() -> null throws never { null : null } !! 1283..1311: expected 1 argument(s), got 0 + !! 1283..1311: type mismatch: expected (int) -> null, got () -> never throws string } lambda user.test_stored_local_shape_mismatch { } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index 8d838111aa..c555f349b5 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -2,54 +2,6 @@ source: crates/baml_tests/src/generated_tests.rs --- === COMPILER2 DIAGNOSTICS === - [validation] Error: Duplicate function `may_fail` - ╭─[ catch_absorbs_throws.baml:3:10 ] - │ - 3 │ function may_fail(x: int) -> string throws string { - │ ────┬─── - │ ╰───── first defined as function here - │ - ├─[ fn_decl_throws.baml:4:10 ] - │ - 4 │ function may_fail(x: int) -> string throws string { - │ ────┬─── - │ ╰───── duplicate function definition - │ - │ Note: Error code: E0011 -───╯ - - [validation] Error: Duplicate function `test_map_pure` - ╭─[ array_map_throws.baml:4:10 ] - │ - 4 │ function test_map_pure() -> int[] { - │ ──────┬────── - │ ╰──────── first defined as function here - │ - ├─[ hof_rethrows.baml:37:10 ] - │ - 37 │ function test_map_pure() -> string { - │ ──────┬────── - │ ╰──────── duplicate function definition - │ - │ Note: Error code: E0011 -────╯ - - [validation] Error: Duplicate function `test_map_throwing` - ╭─[ array_map_throws.baml:10:10 ] - │ - 10 │ function test_map_throwing() -> int[] { - │ ────────┬──────── - │ ╰────────── first defined as function here - │ - ├─[ hof_rethrows.baml:41:10 ] - │ - 41 │ function test_map_throwing() -> string { - │ ────────┬──────── - │ ╰────────── duplicate function definition - │ - │ Note: Error code: E0011 -────╯ - [type] Warning: extraneous throws declaration: string ╭─[ lambda_throws_explicit.baml:17:27 ] │ @@ -70,6 +22,56 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ───╯ + [type] Error: type mismatch: expected () -> int throws string, got () -> never throws int + ╭─[ lambda_throws_violation.baml:11:35 ] + │ + 11 │ let f: () -> int throws string = () -> int { throw 42 } + │ ───────────┬─────────── + │ ╰───────────── type mismatch: expected () -> int throws string, got () -> never throws int + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: extraneous throws declaration: string + ╭─[ lambda_throws_violation.baml:17:52 ] + │ + 17 │ let f: () -> int throws string = () -> int throws string { throw 42 } + │ ───┬─── + │ ╰───── extraneous throws declaration: string + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: extraneous throws declaration: string + ╭─[ lambda_throws_violation.baml:17:52 ] + │ + 17 │ let f: () -> int throws string = () -> int throws string { throw 42 } + │ ───┬─── + │ ╰───── extraneous throws declaration: string + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `string` is missing int + ╭─[ lambda_throws_violation.baml:17:52 ] + │ + 17 │ let f: () -> int throws string = () -> int throws string { throw 42 } + │ ───┬─── + │ ╰───── throws contract violation: `string` is missing int + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `string` is missing int + ╭─[ lambda_throws_violation.baml:17:52 ] + │ + 17 │ let f: () -> int throws string = () -> int throws string { throw 42 } + │ ───┬─── + │ ╰───── throws contract violation: `string` is missing int + │ + │ Note: Error code: E0001 +────╯ + [type] Warning: unreachable code: 1 statement(s) after diverging statement ╭─[ nested_lambda_throws.baml:1:1 ] │ @@ -120,6 +122,16 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ + [type] Error: type mismatch: expected (int) -> null, got () -> never throws string + ╭─[ stored_callback_enforcement.baml:42:25 ] + │ + 42 │ let f: (int) -> null = () -> null { throw "oops" } + │ ──────────────┬───────────── + │ ╰─────────────── type mismatch: expected (int) -> null, got () -> never throws string + │ + │ Note: Error code: E0001 +────╯ + [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type ╭─[ stored_callback_enforcement.baml:64:25 ] │ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 1bbe3effc3..b473869a17 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -31,7 +31,7 @@ function user.apply_and_throw(f: () -> int) -> int { throw } -function user.apply_explicit(f: () -> int) -> int { +function user.apply_explicit(f: () -> int throws string) -> int { load_var f call_indirect return @@ -74,13 +74,13 @@ function user.apply_inner(f: () -> int) -> int { return } -function user.apply_may_throw_class(f: () -> int) -> int { +function user.apply_may_throw_class(f: () -> int throws ApiError) -> int { load_var f call_indirect return } -function user.apply_may_throw_enum(f: () -> int) -> int { +function user.apply_may_throw_enum(f: () -> int throws ErrorKind) -> int { load_var f call_indirect return @@ -98,7 +98,7 @@ function user.apply_pure_only(f: () -> int) -> int { return } -function user.apply_throwing(f: () -> int) -> int { +function user.apply_throwing(f: () -> int throws string) -> int { load_var f call_indirect return @@ -129,7 +129,7 @@ function user.apply_with_many_args(a: int, b: string, f: (int, string) -> string function user.caller() -> string { load_const 1 - call user.may_fail + call user.declared_may_fail return } @@ -228,6 +228,22 @@ function user.catch_mixed_errors(x: int) -> string { return } +function user.caught_may_fail(x: int) -> string { + load_var x + load_const 0 + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_const "ok" + return + + L1: + load_const "zero" + throw +} + function user.compose(f: (void) -> void, g: (void) -> void) -> (void) -> void { load_var ?1 make_cell @@ -241,6 +257,22 @@ function user.compose(f: (void) -> void, g: (void) -> void) -> (void) -> void { return } +function user.declared_may_fail(x: int) -> string { + load_var x + load_const 0 + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_const "ok" + return + + L1: + load_const "zero" + throw +} + function user.helper_with_body_throw() -> int { load_const "helper boom" throw @@ -301,7 +333,7 @@ function user.make_stored_closure_good() -> () -> int { return } -function user.make_stored_closure_with_throws() -> () -> int { +function user.make_stored_closure_with_throws() -> () -> int throws string { make_closure ., 0 return } @@ -314,7 +346,7 @@ function user.make_stored_throwing_handler() -> StoredThrowingHandler { return } -function user.make_thrower() -> () -> int { +function user.make_thrower() -> () -> int throws string { make_closure ., 0 return } @@ -334,20 +366,46 @@ function user.map_it(x: int, f: (int) -> string) -> string { return } -function user.may_fail(x: int) -> string { - load_var x - load_const 0 +function user.optional_call_caught(cb: () -> int throws string?) -> int? { + load_var cb + load_const null cmp_op == pop_jump_if_false L0 jump L1 L0: - load_const "ok" + load_var cb + call_indirect + jump L2 + load_var e + throw_if_panic + load_const null + jump L2 + + L1: + load_const null + + L2: return +} + +function user.optional_call_rethrows(cb: () -> int throws string?) -> int? { + load_var cb + load_const null + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var cb + call_indirect + jump L2 L1: - load_const "zero" - throw + load_const null + + L2: + return } function user.run_pure(f: () -> int) -> int { @@ -377,7 +435,7 @@ function user.run_two(f: () -> int, g: () -> int) -> int { function user.safe_caller() -> string { load_const 1 - call user.may_fail + call user.declared_may_fail jump L0 load_var e throw_if_panic @@ -387,7 +445,7 @@ function user.safe_caller() -> string { return } -function user.takes_throwing(f: () -> int) -> int { +function user.takes_throwing(f: () -> int throws string) -> int { load_var f call_indirect return @@ -441,9 +499,29 @@ function user.test_apply_with_helper_throwing() -> int { return } +function user.test_array_map_pure() -> int[] { + load_const 1 + load_const 2 + load_const 3 + alloc_array 3 + make_closure ., 0 + call baml.Array.map + return +} + +function user.test_array_map_throwing() -> int[] { + load_const 1 + load_const 2 + load_const 3 + alloc_array 3 + make_closure ., 0 + call baml.Array.map + return +} + function user.test_catch_all() -> string { load_const 1 - call user.may_fail + call user.caught_may_fail jump L0 load_var e throw_if_panic @@ -455,7 +533,7 @@ function user.test_catch_all() -> string { function user.test_catch_rethrow() -> string { load_const 1 - call user.may_fail + call user.caught_may_fail jump L0 load_var e throw_if_panic @@ -468,7 +546,7 @@ function user.test_catch_rethrow() -> string { function user.test_catch_string() -> string { load_const 1 - call user.may_fail + call user.caught_may_fail jump L2 load_var e type_tag @@ -597,16 +675,16 @@ function user.test_many_args_pure() -> string { return } -function user.test_map_pure() -> string { +function user.test_map_it_pure() -> string { load_const 1 - make_closure ., 0 + make_closure ., 0 call user.map_it return } -function user.test_map_throwing() -> string { +function user.test_map_it_throwing() -> string { load_const 1 - make_closure ., 0 + make_closure ., 0 call user.map_it return } @@ -652,7 +730,13 @@ function user.test_nested_outer_throws() -> int { function user.test_no_catch() -> string { load_const 1 - call user.may_fail + call user.caught_may_fail + return +} + +function user.test_optional_call_with_throwing_callback() -> int? { + make_closure ., 0 + call user.optional_call_rethrows return } @@ -779,6 +863,18 @@ function user.test_type_alias_mixed() -> int { return } +function user.test_typed_lambda_annotation_underdeclares_body() -> int { + make_closure ., 0 + call_indirect + return +} + +function user.test_typed_lambda_throws_mismatch() -> int { + make_closure ., 0 + call_indirect + return +} + function user.test_use_pure() -> int { call user.make_pure call_indirect @@ -1011,19 +1107,25 @@ function user.throw_various_errors(x: int) -> string { throw } -function user.use_class_thrower(f: () -> int) -> int { +function user.use_alias_pure(f: () -> int) -> int { + load_var f + call_indirect + return +} + +function user.use_class_thrower(f: () -> int throws ApiError) -> int { load_var f call_indirect return } -function user.use_enum_thrower(f: () -> int) -> int { +function user.use_enum_thrower(f: () -> int throws ErrorKind) -> int { load_var f call_indirect return } -function user.use_mixed_thrower(f: () -> int) -> int { +function user.use_mixed_thrower(f: () -> int throws ErrorKind | ApiError | string) -> int { load_var f call_indirect return diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__array_map_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__array_map_throws.snap index 799895a7f7..244cffafe9 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__array_map_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__array_map_throws.snap @@ -4,7 +4,7 @@ source: crates/baml_tests/src/generated_tests.rs // === Array.map with throwing callback === // Pure map callback -function test_map_pure() -> int[] { +function test_array_map_pure() -> int[] { let items: int[] = [1, 2, 3] items .map( @@ -15,7 +15,7 @@ function test_map_pure() -> int[] { } // Throwing map callback -function test_map_throwing() -> int[] { +function test_array_map_throwing() -> int[] { let items: int[] = [1, 2, 3] items .map( diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__catch_absorbs_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__catch_absorbs_throws.snap index 3dd5257123..72e419f638 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__catch_absorbs_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__catch_absorbs_throws.snap @@ -2,4 +2,4 @@ source: crates/baml_tests/src/generated_tests.rs --- === STRONG AST ERROR === -Expected token/node of kind LLM_FUNCTION_BODY or EXPR_FUNCTION_BODY, but found THROWS_CLAUSE at catch_absorbs_throws.baml:3:36 +Expected token/node of kind LLM_FUNCTION_BODY or EXPR_FUNCTION_BODY, but found THROWS_CLAUSE at catch_absorbs_throws.baml:3:43 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_decl_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_decl_throws.snap index dd5bc9a100..e217eab6bc 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_decl_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__fn_decl_throws.snap @@ -2,4 +2,4 @@ source: crates/baml_tests/src/generated_tests.rs --- === STRONG AST ERROR === -Expected token/node of kind LLM_FUNCTION_BODY or EXPR_FUNCTION_BODY, but found THROWS_CLAUSE at fn_decl_throws.baml:4:36 +Expected token/node of kind LLM_FUNCTION_BODY or EXPR_FUNCTION_BODY, but found THROWS_CLAUSE at fn_decl_throws.baml:4:45 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_rethrows.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_rethrows.snap index 6b07e7e834..2a811a0f12 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_rethrows.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__hof_rethrows.snap @@ -70,7 +70,7 @@ function map_it(x: int, f: (int) -> string) -> string { f(x) } -function test_map_pure() -> string { +function test_map_it_pure() -> string { map_it( 1, (n: int) -> string { @@ -79,7 +79,7 @@ function test_map_pure() -> string { ) } -function test_map_throwing() -> string { +function test_map_it_throwing() -> string { map_it( 1, (n: int) -> string { diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_violation.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_violation.snap index 1b6857dba1..f90d9d9adb 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_violation.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__lambda_throws_violation.snap @@ -1,12 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs --- -// === Lambda throws contract violations (should produce errors) === - -// Lambda with explicit throws never but body throws - should error -function test_throws_never_but_throws() -> int { - let f = () -> int throws never { - throw "boom" - } - f() -} +=== STRONG AST ERROR === +Unexpected additional element at lambda_throws_violation.baml:11:19 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__optional_call_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__optional_call_throws.snap new file mode 100644 index 0000000000..e2a460ec29 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__optional_call_throws.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +Unexpected additional element at optional_call_throws.baml:3:47 diff --git a/baml_language/crates/baml_type/src/lib.rs b/baml_language/crates/baml_type/src/lib.rs index 4091306899..9318b46185 100644 --- a/baml_language/crates/baml_type/src/lib.rs +++ b/baml_language/crates/baml_type/src/lib.rs @@ -555,11 +555,20 @@ impl Ty { Ok(()) } // All other variants are fine at runtime - Ty::Function { params, ret, .. } => { + Ty::Function { + params, + ret, + throws, + .. + } => { for p in params { p.validate_runtime()?; } - ret.validate_runtime() + ret.validate_runtime()?; + if let Some(throws) = throws { + throws.validate_runtime()?; + } + Ok(()) } Ty::Future(_, _) => { Err("Future type should not reach runtime (must be awaited)".to_string()) @@ -621,12 +630,21 @@ impl fmt::Display for Ty { types.iter().map(std::string::ToString::to_string).collect(); write!(f, "{}", parts.join(" | ")) } - Ty::Function { params, ret, .. } => { + Ty::Function { + params, + ret, + throws, + .. + } => { let param_strs: Vec = params .iter() .map(std::string::ToString::to_string) .collect(); - write!(f, "({}) -> {}", param_strs.join(", "), ret) + write!(f, "({}) -> {}", param_strs.join(", "), ret)?; + if let Some(throws) = throws { + write!(f, " throws {}", throws)?; + } + Ok(()) } Ty::Void { .. } => write!(f, "void"), Ty::WatchAccessor(inner, _) => write!(f, "{inner}.$watch"), @@ -788,6 +806,31 @@ mod tests { assert_eq!(ty.to_string(), "(int | string)[]"); } + #[test] + fn test_display_function_includes_throws() { + let ty = Ty::Function { + params: vec![ty_int()], + ret: Box::new(ty_string()), + throws: Some(Box::new(ty_bool())), + attr: TyAttr::default(), + }; + assert_eq!(ty.to_string(), "(int) -> string throws bool"); + } + + #[test] + fn test_validate_runtime_checks_function_throws() { + let ty = Ty::Function { + params: vec![], + ret: Box::new(ty_int()), + throws: Some(Box::new(Ty::Future( + Box::new(ty_string()), + TyAttr::default(), + ))), + attr: TyAttr::default(), + }; + assert!(ty.validate_runtime().is_err()); + } + #[test] fn test_validate_runtime_rejects_compiler_types() { assert!( From e45b28d6e05dcc75e686e5f64aadccafa96c9e81 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 13:12:39 -0500 Subject: [PATCH 08/26] Fix lambda generic scope and optional function-shape matching --- .../baml_compiler2_tir/src/inference.rs | 70 ++++++++++++++++--- .../src/throws_semantics.rs | 46 +++++++++++- 2 files changed, 105 insertions(+), 11 deletions(-) diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 361cc42547..83dd242313 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -74,6 +74,8 @@ pub enum MemberResolution<'db> { pub struct ScopeInference<'db> { /// Type of every expression within this scope (NOT nested child scopes). expressions: FxHashMap, + /// Fully resolved type aliases visible to this scope, including dependency exports. + aliases: HashMap, /// Binding types: the type a variable is bound to after widening/annotation. /// May differ from the initializer expression type (e.g. `let x = 1` has /// expression type `Literal(1, Fresh)` but binding type `int`). @@ -172,16 +174,13 @@ impl<'db> ScopeInference<'db> { package_id: PackageId<'db>, body: &baml_compiler2_ast::ExprBody, ) -> crate::ty::Ty { - let pkg_items = baml_compiler2_ppir::package_items(db, package_id); - let aliases = collect_type_aliases(db, pkg_items); - let facts = crate::effective_throws::collect_effective_throws( db, package_id, body, &self.expressions, &self.catch_residual_throws, - &aliases, + &self.aliases, false, false, ); @@ -249,6 +248,19 @@ fn find_lambda_by_span<'a>( None } +fn extend_generic_params_with_enclosing_lambdas( + generic_params: &mut Vec, + body: &ExprBody, + source_map: &AstSourceMap, + enclosing_lambda_spans: &[TextRange], +) { + for lambda_span in enclosing_lambda_spans.iter().rev() { + if let Some((lambda_def, _, _, _)) = find_lambda_by_span(body, source_map, *lambda_span) { + generic_params.extend(lambda_def.generic_params.iter().cloned()); + } + } +} + /// Per-scope type inference — the primary Salsa query for type checking. /// /// Returns expression types for a single scope. Lambda/closure bodies are @@ -288,7 +300,8 @@ pub fn infer_scope_types<'db>( } } let context = InferContext::new(db, scope_id); - let mut builder = TypeInferenceBuilder::new(context, res_ctx, pkg_id, scope_id, aliases); + let mut builder = + TypeInferenceBuilder::new(context, res_ctx, pkg_id, scope_id, aliases.clone()); // Dispatch based on scope kind match &scope.kind { @@ -462,9 +475,13 @@ pub fn infer_scope_types<'db>( } // Walk ancestors to find a Function or Let scope that has a body. + let mut enclosing_lambda_spans = Vec::new(); 'ancestor_walk: for ancestor_fsi in index.ancestor_scopes(file_scope) { let ancestor_scope = &index.scopes[ancestor_fsi.index() as usize]; match &ancestor_scope.kind { + ScopeKind::Lambda => { + enclosing_lambda_spans.push(ancestor_scope.range); + } ScopeKind::Function => { // Find the function by span + name in the item tree for func_data in item_tree.functions.values() { @@ -483,8 +500,33 @@ pub fn infer_scope_types<'db>( if let Some((func_def, lambda_body, _lambda_sm, _lambda_expr_id)) = find_lambda_by_span(func_body, func_sm, lambda_span) { - // Seed builder with lambda params - let generic_params: Vec = func_def.generic_params.clone(); + // Seed builder with the full generic scope: + // enclosing class/function generics, then outer lambda + // generics, then the lambda's own generic params. + let mut generic_params = func_data.generic_params.clone(); + if let Some(parent_idx) = ancestor_scope.parent { + let parent = &index.scopes[parent_idx.index() as usize]; + if matches!(parent.kind, ScopeKind::Class) { + if let Some(class_name) = &parent.name { + for class_data in item_tree.classes.values() { + if class_data.name == *class_name { + let mut merged = + class_data.generic_params.clone(); + merged.extend(generic_params); + generic_params = merged; + break; + } + } + } + } + } + extend_generic_params_with_enclosing_lambdas( + &mut generic_params, + func_body, + func_sm, + &enclosing_lambda_spans, + ); + generic_params.extend(func_def.generic_params.iter().cloned()); builder.set_generic_params(generic_params.clone()); for param in &func_def.params { let param_ty = param @@ -540,8 +582,17 @@ pub fn infer_scope_types<'db>( if let Some((func_def, lambda_body, _lambda_sm, _lambda_expr_id)) = find_lambda_by_span(let_body, &let_sm, lambda_span) { - // Seed builder with lambda params - let generic_params: Vec = func_def.generic_params.clone(); + // Top-level lets have no own generic params, but nested + // lambdas can still inherit generic params from enclosing + // lambdas inside the initializer. + let mut generic_params = Vec::new(); + extend_generic_params_with_enclosing_lambdas( + &mut generic_params, + let_body, + &let_sm, + &enclosing_lambda_spans, + ); + generic_params.extend(func_def.generic_params.iter().cloned()); builder.set_generic_params(generic_params.clone()); for param in &func_def.params { let param_ty = param @@ -630,6 +681,7 @@ pub fn infer_scope_types<'db>( ScopeInference { expressions, + aliases, bindings, resolutions, exhaustive_matches, diff --git a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs index 5a7151fb38..d9a195d3cf 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs @@ -185,8 +185,15 @@ pub(crate) fn function_shape_matches_ignoring_outer_throws( expected: &Ty, aliases: &HashMap, ) -> bool { - let got_resolved = resolve_alias_chain(got, aliases); - let expected_resolved = resolve_alias_chain(expected, aliases); + let mut got_resolved = resolve_alias_chain(got, aliases); + let mut expected_resolved = resolve_alias_chain(expected, aliases); + + while let (Ty::Optional(got_inner, _), Ty::Optional(expected_inner, _)) = + (&got_resolved, &expected_resolved) + { + got_resolved = resolve_alias_chain(got_inner.as_ref(), aliases); + expected_resolved = resolve_alias_chain(expected_inner.as_ref(), aliases); + } match (got_resolved, expected_resolved) { ( @@ -248,3 +255,38 @@ fn declared_covers_fact( _ => normalize::is_subtype_of(fact, declared, aliases), } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::function_shape_matches_ignoring_outer_throws; + use crate::ty::{PrimitiveType, Ty, TyAttr}; + + #[test] + fn optional_function_shapes_match_ignoring_outer_throws() { + let aliases = HashMap::new(); + let got = Ty::Optional( + Box::new(Ty::Function { + params: vec![(None, Ty::Primitive(PrimitiveType::Int, TyAttr::default()))], + ret: Box::new(Ty::Primitive(PrimitiveType::String, TyAttr::default())), + throws: Box::new(Ty::Primitive(PrimitiveType::String, TyAttr::default())), + attr: TyAttr::default(), + }), + TyAttr::default(), + ); + let expected = Ty::Optional( + Box::new(Ty::Function { + params: vec![(None, Ty::Primitive(PrimitiveType::Int, TyAttr::default()))], + ret: Box::new(Ty::Primitive(PrimitiveType::String, TyAttr::default())), + throws: Box::new(Ty::Primitive(PrimitiveType::Bool, TyAttr::default())), + attr: TyAttr::default(), + }), + TyAttr::default(), + ); + + assert!(function_shape_matches_ignoring_outer_throws( + &got, &expected, &aliases + )); + } +} From 32fac735e57fbe4e477c378b1db38be37a32b9b2 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 13:17:41 -0500 Subject: [PATCH 09/26] Clean up lambda throws diagnostics and parenthesize optional function types --- .../crates/baml_compiler2_hir/src/builder.rs | 24 +++++++++++++-- .../crates/baml_compiler2_tir/src/builder.rs | 20 ++++--------- .../crates/baml_lsp2_actions/src/utils.rs | 29 +++++++++++++++---- ...tests__function_type_throws__04_5_mir.snap | 8 ++--- ...l_tests__function_type_throws__04_tir.snap | 4 --- ..._function_type_throws__05_diagnostics.snap | 20 ------------- ...sts__function_type_throws__06_codegen.snap | 4 +-- baml_language/crates/baml_type/src/lib.rs | 15 ++++++++-- 8 files changed, 71 insertions(+), 53 deletions(-) diff --git a/baml_language/crates/baml_compiler2_hir/src/builder.rs b/baml_language/crates/baml_compiler2_hir/src/builder.rs index 87357fb0c1..c8193399fc 100644 --- a/baml_language/crates/baml_compiler2_hir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_hir/src/builder.rs @@ -1102,8 +1102,28 @@ impl<'db> SemanticIndexBuilder<'db> { ast::TypeExpr::Never { .. } => "never".to_string(), ast::TypeExpr::Uint8Array { .. } => "uint8array".to_string(), ast::TypeExpr::Media { kind, .. } => kind.to_string(), - ast::TypeExpr::Optional { inner, .. } => format!("{}?", Self::render_type_expr(inner)), - ast::TypeExpr::List { inner, .. } => format!("{}[]", Self::render_type_expr(inner)), + ast::TypeExpr::Optional { inner, .. } => { + let rendered = Self::render_type_expr(inner); + if matches!( + **inner, + ast::TypeExpr::Union { .. } | ast::TypeExpr::Function { .. } + ) { + format!("({rendered})?") + } else { + format!("{rendered}?") + } + } + ast::TypeExpr::List { inner, .. } => { + let rendered = Self::render_type_expr(inner); + if matches!( + **inner, + ast::TypeExpr::Union { .. } | ast::TypeExpr::Function { .. } + ) { + format!("({rendered})[]") + } else { + format!("{rendered}[]") + } + } ast::TypeExpr::Map { key, value, .. } => format!( "map<{}, {}>", Self::render_type_expr(key), diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index af4eb0e392..72fb9f129d 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -1811,7 +1811,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.report_throws_contract_diff_at_span(&declared_ty, &effective, span); } - fn lower_lambda_throws_annotation( + fn lower_lambda_throws_annotation_silently( &mut self, throws_annotation: &baml_compiler2_ast::SpannedTypeExpr, generic_params: &[Name], @@ -1825,9 +1825,10 @@ impl<'db> TypeInferenceBuilder<'db> { generic_params, &mut diags, ); - for diag in diags { - self.context.report_at_span(diag, throws_annotation.span); - } + // Lambda scopes validate explicit throws annotations separately. + // Keep parent-scope lambda finalization silent so the same span is not + // reported twice in both the enclosing function scope and the lambda scope. + drop(diags); declared } @@ -1840,16 +1841,7 @@ impl<'db> TypeInferenceBuilder<'db> { generic_params: &[Name], ) -> Ty { let throws_ty = if let Some(throws_annotation) = &func_def.throws { - let declared_throws_ty = - self.lower_lambda_throws_annotation(throws_annotation, generic_params); - let effective_throws_facts = - throws_semantics::flatten_ty_to_facts(&effective_throws_ty); - self.report_throws_contract_diff_at_span( - &declared_throws_ty, - &effective_throws_facts, - throws_annotation.span, - ); - declared_throws_ty + self.lower_lambda_throws_annotation_silently(throws_annotation, generic_params) } else { effective_throws_ty }; diff --git a/baml_language/crates/baml_lsp2_actions/src/utils.rs b/baml_language/crates/baml_lsp2_actions/src/utils.rs index 16a4c5be88..19469b9a54 100644 --- a/baml_language/crates/baml_lsp2_actions/src/utils.rs +++ b/baml_language/crates/baml_lsp2_actions/src/utils.rs @@ -128,13 +128,25 @@ pub fn display_ty(ty: &Ty) -> String { PrimitiveType::Pdf => "pdf".to_string(), PrimitiveType::Uint8Array => "uint8array".to_string(), }, - Ty::List(inner, _) => format!("{}[]", display_ty(inner)), + Ty::List(inner, _) => { + let rendered = display_ty(inner); + if matches!(inner.as_ref(), Ty::Union(..) | Ty::Function { .. }) { + format!("({rendered})[]") + } else { + format!("{rendered}[]") + } + } Ty::Map(k, v, _) => format!("map<{}, {}>", display_ty(k), display_ty(v)), Ty::EvolvingList(inner, _) => { if matches!(**inner, Ty::Never { .. }) { "_[]".to_string() } else { - format!("{}[]", display_ty(inner)) + let rendered = display_ty(inner); + if matches!(inner.as_ref(), Ty::Union(..) | Ty::Function { .. }) { + format!("({rendered})[]") + } else { + format!("{rendered}[]") + } } } Ty::EvolvingMap(k, v, _) => { @@ -148,7 +160,14 @@ pub fn display_ty(ty: &Ty) -> String { let parts: Vec<_> = members.iter().map(display_ty).collect(); parts.join(" | ") } - Ty::Optional(inner, _) => format!("{}?", display_ty(inner)), + Ty::Optional(inner, _) => { + let rendered = display_ty(inner); + if matches!(inner.as_ref(), Ty::Union(..) | Ty::Function { .. }) { + format!("({rendered})?") + } else { + format!("{rendered}?") + } + } Ty::Literal(lit, _freshness, _) => lit.to_string(), Ty::Function { params, @@ -206,7 +225,7 @@ pub fn display_type_expr(te: &TypeExpr) -> String { TypeExpr::Media { kind, .. } => format!("{kind:?}").to_lowercase(), TypeExpr::Optional { inner, .. } => { let s = display_type_expr(inner); - if matches!(**inner, TypeExpr::Union { .. }) { + if matches!(**inner, TypeExpr::Union { .. } | TypeExpr::Function { .. }) { format!("({s})?") } else { format!("{s}?") @@ -214,7 +233,7 @@ pub fn display_type_expr(te: &TypeExpr) -> String { } TypeExpr::List { inner, .. } => { let s = display_type_expr(inner); - if matches!(**inner, TypeExpr::Union { .. }) { + if matches!(**inner, TypeExpr::Union { .. } | TypeExpr::Function { .. }) { format!("({s})[]") } else { format!("{s}[]") diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index 6754580aca..de02e5de6e 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -3199,10 +3199,10 @@ fn ., 1)>() -> null { } } -fn user.optional_call_caught(cb: () -> int throws string?) -> int? { +fn user.optional_call_caught(cb: (() -> int throws string)?) -> int? { // Locals: let _0: int? // _0 // return - let _1: () -> int throws string? // cb // param + let _1: (() -> int throws string)? // cb // param let _2: unknown // e let _3: bool @@ -3250,10 +3250,10 @@ fn user.optional_call_caught(cb: () -> int throws string?) -> int? { } } -fn user.optional_call_rethrows(cb: () -> int throws string?) -> int? { +fn user.optional_call_rethrows(cb: (() -> int throws string)?) -> int? { // Locals: let _0: int? // _0 // return - let _1: () -> int throws string? // cb // param + let _1: (() -> int throws string)? // cb // param let _2: bool bb0: { diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 1623f8b300..31ec4435b9 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -883,7 +883,6 @@ function user.test_explicit_throws_wider_than_body() -> int throws string { } f() : 42 } - ?? 547..554: extraneous throws declaration: string } lambda user.test_explicit_throws_wider_than_body { ?? 547..554: extraneous throws declaration: string @@ -961,7 +960,6 @@ function user.test_throws_never_but_throws() -> int throws never { } f() : never } - !! 213..219: throws contract violation: `never` is missing string } lambda user.test_throws_never_but_throws { !! 213..219: throws contract violation: `never` is missing string @@ -988,8 +986,6 @@ function user.test_typed_lambda_annotation_underdeclares_body() -> int throws st } f() : never } - !! 654..661: throws contract violation: `string` is missing int - ?? 654..661: extraneous throws declaration: string } lambda user.test_typed_lambda_annotation_underdeclares_body { !! 654..661: throws contract violation: `string` is missing int diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index c555f349b5..290499edb6 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -42,26 +42,6 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ - [type] Warning: extraneous throws declaration: string - ╭─[ lambda_throws_violation.baml:17:52 ] - │ - 17 │ let f: () -> int throws string = () -> int throws string { throw 42 } - │ ───┬─── - │ ╰───── extraneous throws declaration: string - │ - │ Note: Error code: E0001 -────╯ - - [type] Error: throws contract violation: `string` is missing int - ╭─[ lambda_throws_violation.baml:17:52 ] - │ - 17 │ let f: () -> int throws string = () -> int throws string { throw 42 } - │ ───┬─── - │ ╰───── throws contract violation: `string` is missing int - │ - │ Note: Error code: E0001 -────╯ - [type] Error: throws contract violation: `string` is missing int ╭─[ lambda_throws_violation.baml:17:52 ] │ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index b473869a17..9d005516e5 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -366,7 +366,7 @@ function user.map_it(x: int, f: (int) -> string) -> string { return } -function user.optional_call_caught(cb: () -> int throws string?) -> int? { +function user.optional_call_caught(cb: (() -> int throws string)?) -> int? { load_var cb load_const null cmp_op == @@ -389,7 +389,7 @@ function user.optional_call_caught(cb: () -> int throws string?) -> int? { return } -function user.optional_call_rethrows(cb: () -> int throws string?) -> int? { +function user.optional_call_rethrows(cb: (() -> int throws string)?) -> int? { load_var cb load_const null cmp_op == diff --git a/baml_language/crates/baml_type/src/lib.rs b/baml_language/crates/baml_type/src/lib.rs index 9318b46185..53e2c3e1ba 100644 --- a/baml_language/crates/baml_type/src/lib.rs +++ b/baml_language/crates/baml_type/src/lib.rs @@ -611,14 +611,14 @@ impl fmt::Display for Ty { Ty::Opaque(tn, _) => write!(f, "{tn}"), Ty::TypeAlias(tn, _) => write!(f, "{tn}"), Ty::Optional(inner, _) => { - if matches!(inner.as_ref(), Ty::Union(..)) { + if matches!(inner.as_ref(), Ty::Union(..) | Ty::Function { .. }) { write!(f, "({inner})?") } else { write!(f, "{inner}?") } } Ty::List(inner, _) => { - if matches!(inner.as_ref(), Ty::Union(..)) { + if matches!(inner.as_ref(), Ty::Union(..) | Ty::Function { .. }) { write!(f, "({inner})[]") } else { write!(f, "{inner}[]") @@ -806,6 +806,17 @@ mod tests { assert_eq!(ty.to_string(), "(int | string)[]"); } + #[test] + fn test_display_optional_function_parenthesized() { + let ty = Ty::optional(Ty::Function { + params: vec![ty_int()], + ret: Box::new(ty_string()), + throws: Some(Box::new(ty_bool())), + attr: TyAttr::default(), + }); + assert_eq!(ty.to_string(), "((int) -> string throws bool)?"); + } + #[test] fn test_display_function_includes_throws() { let ty = Ty::Function { From cef2fc866bfd10791158d5274933c46ae3b43f99 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 13:31:31 -0500 Subject: [PATCH 10/26] Update MIR and codegen snapshots for function throws display --- .../baml_tests____baml_std____04_5_mir.snap | 16 +++++++------- .../baml_tests____baml_std____06_codegen.snap | 2 +- ...baml_tests____testing_std____04_5_mir.snap | 22 +++++++++---------- ...ml_tests____testing_std____06_codegen.snap | 8 +++---- .../baml_tests__lambda_basic__04_5_mir.snap | 4 ++-- 5 files changed, 26 insertions(+), 26 deletions(-) diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap index f75cfcf7e9..39643f838b 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap @@ -500,10 +500,10 @@ fn baml.llm.Client.build_attempt_with_state(self: baml.llm.Client, planner_state let _2: baml.llm.PlannerState // planner_state // param let _3: baml.llm.ClientType let _4: int - let _5: () -> baml.llm.PrimitiveClient // resolve_fn - let _6: future<() -> baml.llm.PrimitiveClient> + let _5: () -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument // resolve_fn + let _6: future<() -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument> let _7: baml.llm.PrimitiveClient // primitive - let _8: () -> baml.llm.PrimitiveClient + let _8: () -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument let _9: baml.llm.OrchestrationStep let _10: baml.llm.PrimitiveClient let _11: baml.llm.OrchestrationStep[] // steps @@ -1094,10 +1094,10 @@ fn baml.llm.Client.execute_once(self: baml.llm.Client, context: baml.llm.Executi let _3: int // active_delay_ms // param let _4: baml.llm.ClientType let _5: int - let _6: () -> baml.llm.PrimitiveClient // resolve_fn - let _7: future<() -> baml.llm.PrimitiveClient> + let _6: () -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument // resolve_fn + let _7: future<() -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument> let _8: baml.llm.PrimitiveClient // primitive - let _9: () -> baml.llm.PrimitiveClient + let _9: () -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument let _10: type // return_type let _11: string let _12: future @@ -1408,8 +1408,8 @@ fn baml.llm.Client.to_primitive_client(self: baml.llm.Client) -> baml.llm.Primit let _4: void let _5: string let _6: string - let _7: () -> baml.llm.PrimitiveClient - let _8: future<() -> baml.llm.PrimitiveClient> + let _7: () -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument + let _8: future<() -> baml.llm.PrimitiveClient throws baml.errors.InvalidArgument> bb0: { _3 = copy _1.0; diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap index c83060193e..e99fd00a24 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap @@ -784,7 +784,7 @@ function baml.llm.Client.execute_once(self: null, context: void, active_delay_ms return } -function baml.llm.Client.get_constructor(self: null) -> () -> void { +function baml.llm.Client.get_constructor(self: null) -> () -> void throws baml.errors.InvalidArgument { } function baml.llm.Client.to_primitive_client(self: null) -> void { diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap index da82f7bcf5..01b9ab2047 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap @@ -43,7 +43,7 @@ fn testing.TestRegistry.expand_set(self: testing.TestRegistry, name: string) -> let _15: void[] let _16: void[] let _17: null - let _18: (testing.TestCollector) -> null + let _18: (testing.TestCollector) -> null throws unknown let _19: testing.TestSetRegistration let _20: testing.TestCollector let _21: int // elapsed @@ -309,13 +309,13 @@ fn testing.TestCollector.new(prefix: string) -> testing.TestCollector { } } -fn testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> null { +fn testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> null { // Locals: let _0: null // _0 // return let _1: testing.TestCollector // self // param let _2: string // name // param let _3: () -> null // body // param - let _4: (() -> testing.TestReport) -> () -> testing.TestReport? // runner // param + let _4: ((() -> testing.TestReport) -> () -> testing.TestReport)? // runner // param let _5: string // full_name let _6: bool let _7: string @@ -469,13 +469,13 @@ fn testing.TestCollector.register_test(self: testing.TestCollector, name: string } } -fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> null { +fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> null { // Locals: let _0: null // _0 // return let _1: testing.TestCollector // self // param let _2: string // name // param let _3: (testing.TestCollector) -> null // collector // param - let _4: (() -> testing.TestSetReport) -> () -> testing.TestSetReport? // runner // param + let _4: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)? // runner // param let _5: string // full_name let _6: bool let _7: string @@ -629,11 +629,11 @@ fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: st } } -fn testing.run_test(body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> testing.TestReport { +fn testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> testing.TestReport { // Locals: let _0: testing.TestReport // _0 // return let _1: () -> null // body // param [captured] - let _2: (() -> testing.TestReport) -> () -> testing.TestReport? // runner // param + let _2: ((() -> testing.TestReport) -> () -> testing.TestReport)? // runner // param let _3: () -> testing.TestReport // base_run let _4: () -> testing.TestReport // effective_run let _5: bool @@ -771,9 +771,9 @@ fn testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> te let _10: bool let _11: string let _12: testing.TestRegistration - let _13: () -> null + let _13: () -> null throws unknown let _14: testing.TestRegistration - let _15: (() -> testing.TestReport) -> () -> testing.TestReport? + let _15: ((() -> testing.TestReport) -> () -> testing.TestReport)? let _16: testing.TestRegistration let _17: string[] let _18: map @@ -901,11 +901,11 @@ fn testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> te } } -fn testing.run_testset(run_children: () -> testing.TestSetReport, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> testing.TestSetReport { +fn testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> testing.TestSetReport { // Locals: let _0: testing.TestSetReport // _0 // return let _1: () -> testing.TestSetReport // run_children // param - let _2: (() -> testing.TestSetReport) -> () -> testing.TestSetReport? // runner // param + let _2: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)? // runner // param let _3: () -> testing.TestSetReport // effective_run let _4: bool let _5: bool diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap index 1573371862..d654398641 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap @@ -67,7 +67,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> null { load_var self load_field .prefix load_const "" @@ -191,7 +191,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () jump L3 } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> null { load_var self load_field .prefix load_const "" @@ -686,7 +686,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | jump L0 } -function testing.run_test(body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> testing.TestReport { +function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> testing.TestReport { load_var ?1 make_cell store_var ?1 @@ -721,7 +721,7 @@ function testing.run_test(body: () -> null, runner: (() -> testing.TestReport) - return } -function testing.run_testset(run_children: () -> testing.TestSetReport, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> testing.TestSetReport { +function testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> testing.TestSetReport { load_var runner load_const null cmp_op == diff --git a/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_5_mir.snap index 96e95dd6c3..c155a7aa54 100644 --- a/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/lambda_basic/baml_tests__lambda_basic__04_5_mir.snap @@ -296,8 +296,8 @@ fn ., 1)>(y: int) -> null { fn user.test_throws() -> int { // Locals: let _0: int // _0 // return - let _1: (int) -> int // risky - let _2: (int) -> int + let _1: (int) -> int throws string // risky + let _2: (int) -> int throws string bb0: { _1 = make_closure lambda[0](); From be22d2dfffe77db3494735ec48d91a42d0e24fd5 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Wed, 8 Apr 2026 14:21:23 -0500 Subject: [PATCH 11/26] Update test expectations for typed rethrows changes LSP tests and bytecode snapshots expected stale throws contract violation diagnostics and old optional-function-type formatting that changed with the typed rethrows for higher-order functions work. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../syntax/hover/function_throws.baml | 19 +--------- .../throw/throws_caller_sees_contract.baml | 37 +------------------ .../throw/throws_caller_variant_contract.baml | 28 +------------- .../syntax/throw/throws_enum_exact_match.baml | 19 +--------- .../syntax/throw/throws_enum_extraneous.baml | 19 +--------- .../throw/throws_enum_variant_precise.baml | 19 +--------- .../throw/throws_enum_variant_violation.baml | 19 +--------- .../test_files/syntax/throw/throws_mixed.baml | 19 +--------- ...ode_format__bytecode_display_expanded.snap | 9 +++-- ...bytecode_display_expanded_unoptimized.snap | 9 +++-- ...code_format__bytecode_display_textual.snap | 9 +++-- 11 files changed, 23 insertions(+), 183 deletions(-) diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/hover/function_throws.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/hover/function_throws.baml index 0cb7868ae0..27ce06a6f2 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/hover/function_throws.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/hover/function_throws.baml @@ -9,24 +9,7 @@ function MayFail<[CURSOR](x: int) -> string throws Errors { //---- //- diagnostics -// Error: throws contract violation: `user.Errors` is missing user.Errors.AuthError -// ╭─[ hover_function_throws.baml:6:42 ] -// │ -// 6 │ function MayFail(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── throws contract violation: `user.Errors` is missing user.Errors.AuthError -// │ -// │ Note: Error code: E0001 -// ───╯ -// Warning: extraneous throws declaration: user.Errors -// ╭─[ hover_function_throws.baml:6:42 ] -// │ -// 6 │ function MayFail(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── extraneous throws declaration: user.Errors -// │ -// │ Note: Error code: E0001 -// ───╯ +// // //- on_hover expressions // hover at hover_function_throws.baml:6:17 diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_caller_sees_contract.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_caller_sees_contract.baml index 27374a2924..22e8c2f600 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_caller_sees_contract.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_caller_sees_contract.baml @@ -20,39 +20,4 @@ function Caller(x: int) -> string { //---- //- diagnostics -// Error: throws contract violation: `user.Errors` is missing user.Errors.AuthError -// ╭─[ throw_throws_caller_sees_contract.baml:9:42 ] -// │ -// 9 │ function MayFail(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── throws contract violation: `user.Errors` is missing user.Errors.AuthError -// │ -// │ Note: Error code: E0001 -// ───╯ -// Warning: extraneous throws declaration: user.Errors -// ╭─[ throw_throws_caller_sees_contract.baml:9:42 ] -// │ -// 9 │ function MayFail(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── extraneous throws declaration: user.Errors -// │ -// │ Note: Error code: E0001 -// ───╯ -// Warning: unreachable arm -// ╭─[ throw_throws_caller_sees_contract.baml:16:29 ] -// │ -// 16 │ Errors.NotFoundError => "notfound", -// │ ─────┬──── -// │ ╰────── unreachable arm -// │ -// │ Note: Error code: E0001 -// ────╯ -// Warning: unreachable arm -// ╭─[ throw_throws_caller_sees_contract.baml:17:29 ] -// │ -// 17 │ Errors.InternalError => "internal" -// │ ─────┬──── -// │ ╰────── unreachable arm -// │ -// │ Note: Error code: E0001 -// ────╯ +// diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_caller_variant_contract.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_caller_variant_contract.baml index 2155942985..718686f386 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_caller_variant_contract.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_caller_variant_contract.baml @@ -23,30 +23,4 @@ function CallerGood(x: int) -> string { //---- //- diagnostics -// Error: throws contract violation: `user.Errors` is missing user.Errors.AuthError, user.Errors.NotFoundError -// ╭─[ throw_throws_caller_variant_contract.baml:10:42 ] -// │ -// 10 │ function MayFail(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── throws contract violation: `user.Errors` is missing user.Errors.AuthError, user.Errors.NotFoundError -// │ -// │ Note: Error code: E0001 -// ────╯ -// Warning: extraneous throws declaration: user.Errors -// ╭─[ throw_throws_caller_variant_contract.baml:10:42 ] -// │ -// 10 │ function MayFail(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── extraneous throws declaration: user.Errors -// │ -// │ Note: Error code: E0001 -// ────╯ -// Warning: unreachable arm -// ╭─[ throw_throws_caller_variant_contract.baml:20:29 ] -// │ -// 20 │ Errors.NotFoundError => "notfound" -// │ ─────┬──── -// │ ╰────── unreachable arm -// │ -// │ Note: Error code: E0001 -// ────╯ +// diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_exact_match.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_exact_match.baml index 622a1f3abc..9514d4e9a1 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_exact_match.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_exact_match.baml @@ -16,21 +16,4 @@ function ThrowsEnumExact(x: int) -> string throws Errors { //---- //- diagnostics -// Error: throws contract violation: `user.Errors` is missing user.Errors.AuthError, user.Errors.InternalError, user.Errors.NotFoundError -// ╭─[ throw_throws_enum_exact_match.baml:9:52 ] -// │ -// 9 │ function ThrowsEnumExact(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── throws contract violation: `user.Errors` is missing user.Errors.AuthError, user.Errors.InternalError, user.Errors.NotFoundError -// │ -// │ Note: Error code: E0001 -// ───╯ -// Warning: extraneous throws declaration: user.Errors -// ╭─[ throw_throws_enum_exact_match.baml:9:52 ] -// │ -// 9 │ function ThrowsEnumExact(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── extraneous throws declaration: user.Errors -// │ -// │ Note: Error code: E0001 -// ───╯ +// diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_extraneous.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_extraneous.baml index 96e75c00ed..20b7951273 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_extraneous.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_extraneous.baml @@ -12,21 +12,4 @@ function ThrowsEnumExtraneous(x: int) -> string throws Errors { //---- //- diagnostics -// Error: throws contract violation: `user.Errors` is missing user.Errors.AuthError -// ╭─[ throw_throws_enum_extraneous.baml:9:57 ] -// │ -// 9 │ function ThrowsEnumExtraneous(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── throws contract violation: `user.Errors` is missing user.Errors.AuthError -// │ -// │ Note: Error code: E0001 -// ───╯ -// Warning: extraneous throws declaration: user.Errors -// ╭─[ throw_throws_enum_extraneous.baml:9:57 ] -// │ -// 9 │ function ThrowsEnumExtraneous(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── extraneous throws declaration: user.Errors -// │ -// │ Note: Error code: E0001 -// ───╯ +// diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_variant_precise.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_variant_precise.baml index 68baf9585d..b8faedc8f4 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_variant_precise.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_variant_precise.baml @@ -18,21 +18,4 @@ function ThrowsVariantExact(x: int) -> string throws Errors { //---- //- diagnostics -// Error: throws contract violation: `user.Errors` is missing user.Errors.AuthError, user.Errors.NotFoundError -// ╭─[ throw_throws_enum_variant_precise.baml:12:55 ] -// │ -// 12 │ function ThrowsVariantExact(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── throws contract violation: `user.Errors` is missing user.Errors.AuthError, user.Errors.NotFoundError -// │ -// │ Note: Error code: E0001 -// ────╯ -// Warning: extraneous throws declaration: user.Errors -// ╭─[ throw_throws_enum_variant_precise.baml:12:55 ] -// │ -// 12 │ function ThrowsVariantExact(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── extraneous throws declaration: user.Errors -// │ -// │ Note: Error code: E0001 -// ────╯ +// diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_variant_violation.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_variant_violation.baml index ae1912f5f8..f85d86c151 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_variant_violation.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_enum_variant_violation.baml @@ -16,21 +16,4 @@ function ThrowsVariantViolation(x: int) -> string throws Errors { //---- //- diagnostics -// Error: throws contract violation: `user.Errors` is missing user.Errors.AuthError, user.Errors.NotFoundError -// ╭─[ throw_throws_enum_variant_violation.baml:10:59 ] -// │ -// 10 │ function ThrowsVariantViolation(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── throws contract violation: `user.Errors` is missing user.Errors.AuthError, user.Errors.NotFoundError -// │ -// │ Note: Error code: E0001 -// ────╯ -// Warning: extraneous throws declaration: user.Errors -// ╭─[ throw_throws_enum_variant_violation.baml:10:59 ] -// │ -// 10 │ function ThrowsVariantViolation(x: int) -> string throws Errors { -// │ ───┬─── -// │ ╰───── extraneous throws declaration: user.Errors -// │ -// │ Note: Error code: E0001 -// ────╯ +// diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_mixed.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_mixed.baml index 81240beb70..370f7f9a54 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_mixed.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/throws_mixed.baml @@ -15,21 +15,4 @@ function ThrowsMixed(x: int) -> string throws string | Errors { //---- //- diagnostics -// Error: throws contract violation: `string | user.Errors` is missing user.Errors.AuthError -// ╭─[ throw_throws_mixed.baml:9:46 ] -// │ -// 9 │ function ThrowsMixed(x: int) -> string throws string | Errors { -// │ ────────┬─────── -// │ ╰───────── throws contract violation: `string | user.Errors` is missing user.Errors.AuthError -// │ -// │ Note: Error code: E0001 -// ───╯ -// Warning: extraneous throws declaration: user.Errors -// ╭─[ throw_throws_mixed.baml:9:46 ] -// │ -// 9 │ function ThrowsMixed(x: int) -> string throws string | Errors { -// │ ────────┬─────── -// │ ╰───────── extraneous throws declaration: user.Errors -// │ -// │ Note: Error code: E0001 -// ───╯ +// diff --git a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded.snap b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded.snap index 77118e1cde..c33d91f95d 100644 --- a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded.snap +++ b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/tests/bytecode_format/main.rs +assertion_line: 42 --- function assert.contains(haystack: string, needle: string) -> null { 0 load_var 1 (haystack) @@ -113,7 +114,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { 10 return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> null { 0 load_var 1 (self) 1 load_field 0 (prefix) 2 load_const 0 ("") @@ -213,7 +214,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () 96 jump -69 (to 27) } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> null { 0 load_var 1 (self) 1 load_field 0 (prefix) 2 load_const 0 ("") @@ -622,7 +623,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | 103 jump -95 (to 8) } -function testing.run_test(body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> testing.TestReport { +function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> testing.TestReport { 0 load_var 1 1 make_cell 2 store_var 1 @@ -651,7 +652,7 @@ function testing.run_test(body: () -> null, runner: (() -> testing.TestReport) - 25 return } -function testing.run_testset(run_children: () -> testing.TestSetReport, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> testing.TestSetReport { +function testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> testing.TestSetReport { 0 load_var 2 (runner) 1 load_const 0 (null) 2 cmp_op == diff --git a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded_unoptimized.snap b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded_unoptimized.snap index 77118e1cde..608f662f36 100644 --- a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded_unoptimized.snap +++ b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded_unoptimized.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/tests/bytecode_format/main.rs +assertion_line: 43 --- function assert.contains(haystack: string, needle: string) -> null { 0 load_var 1 (haystack) @@ -113,7 +114,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { 10 return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> null { 0 load_var 1 (self) 1 load_field 0 (prefix) 2 load_const 0 ("") @@ -213,7 +214,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () 96 jump -69 (to 27) } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> null { 0 load_var 1 (self) 1 load_field 0 (prefix) 2 load_const 0 ("") @@ -622,7 +623,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | 103 jump -95 (to 8) } -function testing.run_test(body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> testing.TestReport { +function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> testing.TestReport { 0 load_var 1 1 make_cell 2 store_var 1 @@ -651,7 +652,7 @@ function testing.run_test(body: () -> null, runner: (() -> testing.TestReport) - 25 return } -function testing.run_testset(run_children: () -> testing.TestSetReport, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> testing.TestSetReport { +function testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> testing.TestSetReport { 0 load_var 2 (runner) 1 load_const 0 (null) 2 cmp_op == diff --git a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_textual.snap b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_textual.snap index 23bffa31a8..eabcca30a7 100644 --- a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_textual.snap +++ b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_textual.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/tests/bytecode_format/main.rs +assertion_line: 41 --- function assert.contains(haystack: string, needle: string) -> null { load_var haystack @@ -147,7 +148,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> null { load_var self load_field .prefix load_const "" @@ -271,7 +272,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () jump L3 } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> null { load_var self load_field .prefix load_const "" @@ -766,7 +767,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | jump L0 } -function testing.run_test(body: () -> null, runner: (() -> testing.TestReport) -> () -> testing.TestReport?) -> testing.TestReport { +function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> testing.TestReport { load_var ?1 make_cell store_var ?1 @@ -801,7 +802,7 @@ function testing.run_test(body: () -> null, runner: (() -> testing.TestReport) - return } -function testing.run_testset(run_children: () -> testing.TestSetReport, runner: (() -> testing.TestSetReport) -> () -> testing.TestSetReport?) -> testing.TestSetReport { +function testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> testing.TestSetReport { load_var runner load_const null cmp_op == From a5d14b8f3a93e46e3fe16efcc57aa5015483ea21 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 00:05:07 -0500 Subject: [PATCH 12/26] Extract callable_boundary module and tighten effect-polymorphic rethrows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolidate signature lowering (params, ret, throws, effect vars) into a shared `callable_boundary` module used by builder, inference, and throw_inference — eliminating duplicated logic. Only widen outward throws for callbacks that are actually invoked in the body, add diagnostic formatting that renders `__throws_` type vars as human-readable messages, and extend test coverage for under-declared rethrows, unused callbacks, stored callback enforcement, and generic class methods. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../baml_std/testing/registry.baml | 14 +- .../baml_std/testing/types.baml | 4 +- .../crates/baml_compiler2_hir/src/builder.rs | 57 ++- .../crates/baml_compiler2_tir/src/builder.rs | 433 +++--------------- .../src/callable_boundary.rs | 154 +++++++ .../src/effective_throws.rs | 55 ++- .../baml_compiler2_tir/src/infer_context.rs | 31 +- .../baml_compiler2_tir/src/inference.rs | 104 +++-- .../crates/baml_compiler2_tir/src/lib.rs | 1 + .../baml_compiler2_tir/src/throw_inference.rs | 68 +-- .../src/throws_semantics.rs | 113 +++++ .../class_field_fn_throws.baml | 17 + .../function_type_throws/hof_throws.baml | 23 + .../stored_callback_enforcement.baml | 7 + .../baml_tests____testing_std____03_hir.snap | 4 +- ...baml_tests____testing_std____04_5_mir.snap | 92 ++-- .../baml_tests____testing_std____04_tir.snap | 32 +- ...ml_tests____testing_std____06_codegen.snap | 20 +- ...rows__01_lexer__class_field_fn_throws.snap | 95 ++++ ...ion_type_throws__01_lexer__hof_throws.snap | 132 ++++++ ...01_lexer__stored_callback_enforcement.snap | 55 +++ ...ows__02_parser__class_field_fn_throws.snap | 151 ++++++ ...on_type_throws__02_parser__hof_throws.snap | 150 ++++++ ...2_parser__stored_callback_enforcement.snap | 39 ++ ...l_tests__function_type_throws__03_hir.snap | 30 ++ ...tests__function_type_throws__04_5_mir.snap | 212 +++++++++ ...l_tests__function_type_throws__04_tir.snap | 102 ++++- ..._function_type_throws__05_diagnostics.snap | 46 +- ...sts__function_type_throws__06_codegen.snap | 60 +++ 29 files changed, 1685 insertions(+), 616 deletions(-) create mode 100644 baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs diff --git a/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml b/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml index 80b9de728d..c3f055e3f4 100644 --- a/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml +++ b/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml @@ -187,18 +187,16 @@ function run_test(body: () -> null, runner: TestRunner?) -> TestReport { runs: [RunReport { outcome: result.outcome, duration_ms: baml.sys.now_ms() - start }], } } - let effective_run = match (runner) { - null => base_run, - r: TestRunner => r(base_run) + match (runner) { + null => base_run(), + r: TestRunner => r(base_run)() } - effective_run() } // Run all children of a testset with optional runner middleware. function run_testset(run_children: () -> TestSetReport, runner: TestSetRunner?) -> TestSetReport { - let effective_run = match (runner) { - null => run_children, - r: TestSetRunner => r(run_children) + match (runner) { + null => run_children(), + r: TestSetRunner => r(run_children)() } - effective_run() } diff --git a/baml_language/crates/baml_builtins2/baml_std/testing/types.baml b/baml_language/crates/baml_builtins2/baml_std/testing/types.baml index 8f1927cc6e..4b52f57cfb 100644 --- a/baml_language/crates/baml_builtins2/baml_std/testing/types.baml +++ b/baml_language/crates/baml_builtins2/baml_std/testing/types.baml @@ -31,7 +31,7 @@ type ChildReport = TestReport | TestSetReport // Runner type aliases — lambda transformers. // TestRunner wraps "run test once" into a new execution lambda. -type TestRunner = (() -> TestReport) -> () -> TestReport +type TestRunner = (() -> TestReport throws unknown) -> (() -> TestReport throws unknown) // TestSetRunner wraps "run all children" into a new execution lambda. -type TestSetRunner = (() -> TestSetReport) -> () -> TestSetReport +type TestSetRunner = (() -> TestSetReport throws unknown) -> (() -> TestSetReport throws unknown) diff --git a/baml_language/crates/baml_compiler2_hir/src/builder.rs b/baml_language/crates/baml_compiler2_hir/src/builder.rs index c8193399fc..f61e2c805f 100644 --- a/baml_language/crates/baml_compiler2_hir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_hir/src/builder.rs @@ -1036,11 +1036,19 @@ impl<'db> SemanticIndexBuilder<'db> { Self::collect_unknown_type_attrs(v, diagnostics); } } - ast::TypeExpr::Function { params, ret, .. } => { + ast::TypeExpr::Function { + params, + ret, + throws, + .. + } => { for p in params { Self::collect_unknown_type_attrs(&p.ty, diagnostics); } Self::collect_unknown_type_attrs(ret, diagnostics); + if let Some(throws) = throws { + Self::collect_unknown_type_attrs(throws, diagnostics); + } } _ => {} } @@ -1058,11 +1066,19 @@ impl<'db> SemanticIndexBuilder<'db> { ast::TypeExpr::Union { variants, .. } => { variants.iter().any(Self::type_expr_contains_rust) } - ast::TypeExpr::Function { params, ret, .. } => { + ast::TypeExpr::Function { + params, + ret, + throws, + .. + } => { params .iter() .any(|param| Self::type_expr_contains_rust(¶m.ty)) || Self::type_expr_contains_rust(ret) + || throws + .as_ref() + .is_some_and(|throws| Self::type_expr_contains_rust(throws)) } _ => false, } @@ -1135,18 +1151,31 @@ impl<'db> SemanticIndexBuilder<'db> { .collect::>() .join(" | "), ast::TypeExpr::Literal { value, .. } => value.to_string(), - ast::TypeExpr::Function { params, ret, .. } => format!( - "({}) -> {}", - params - .iter() - .map(|param| match ¶m.name { - Some(name) => format!("{}: {}", name, Self::render_type_expr(¶m.ty)), - None => Self::render_type_expr(¶m.ty), - }) - .collect::>() - .join(", "), - Self::render_type_expr(ret) - ), + ast::TypeExpr::Function { + params, + ret, + throws, + .. + } => { + let mut rendered = format!( + "({}) -> {}", + params + .iter() + .map(|param| match ¶m.name { + Some(name) => + format!("{}: {}", name, Self::render_type_expr(¶m.ty)), + None => Self::render_type_expr(¶m.ty), + }) + .collect::>() + .join(", "), + Self::render_type_expr(ret) + ); + if let Some(throws) = throws { + rendered.push_str(" throws "); + rendered.push_str(&Self::render_type_expr(throws)); + } + rendered + } ast::TypeExpr::BuiltinUnknown { .. } => "unknown".to_string(), ast::TypeExpr::Type { .. } => "type".to_string(), ast::TypeExpr::Rust { .. } => "$rust_type".to_string(), diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 72fb9f129d..827b1d8c0e 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -26,6 +26,7 @@ use rustc_hash::{FxHashMap, FxHashSet}; use text_size::TextRange; use crate::{ + callable_boundary::lower_callable_boundary, infer_context::{InferContext, RelatedLocation, TirTypeError, TypeCheckDiagnostics}, package_interface::PackageResolutionContext, throws_semantics, @@ -1866,12 +1867,12 @@ impl<'db> TypeInferenceBuilder<'db> { let mut extra: Vec = diff .uncovered_effective .into_iter() - .map(|ty| ty.to_string()) + .map(|ty| throws_semantics::format_throw_fact_for_diagnostic(&ty)) .collect(); let mut extraneous: Vec = diff .extraneous_declared .into_iter() - .map(|ty| ty.to_string()) + .map(|ty| throws_semantics::format_throw_fact_for_diagnostic(&ty)) .collect(); extra.sort(); extraneous.sort(); @@ -2340,9 +2341,17 @@ impl<'db> TypeInferenceBuilder<'db> { } fn catch_base_throw_types(&self, base_expr_id: ExprId, body: &ExprBody) -> BTreeSet { - let mut out = BTreeSet::new(); - self.collect_throw_facts_from_expr(base_expr_id, body, &mut out); - out + crate::effective_throws::collect_effective_throws_from_root_expr( + self.context.db(), + self.package_id, + base_expr_id, + body, + &self.expressions, + &self.catch_residual_throws, + &self.aliases, + true, + true, + ) } /// Join a set of throw fact types into a single type. @@ -2505,272 +2514,6 @@ impl<'db> TypeInferenceBuilder<'db> { ) } - fn collect_throw_facts_from_expr( - &self, - expr_id: ExprId, - body: &ExprBody, - out: &mut BTreeSet, - ) { - match &body.exprs[expr_id] { - Expr::Throw { value } => { - self.collect_throw_facts_from_expr(*value, body, out); - self.collect_throw_facts_from_value(*value, body, out); - } - Expr::Call { callee, args } => { - self.collect_throw_facts_from_expr(*callee, body, out); - for arg in args { - self.collect_throw_facts_from_expr(*arg, body, out); - } - let type_level_facts = self - .expressions - .get(callee) - .and_then(|ty| throws_semantics::function_throws_facts(ty, &self.aliases)) - .and_then(|facts| if facts.is_empty() { None } else { Some(facts) }); - if let Some(facts) = type_level_facts { - out.extend(facts); - } else if let Some(target) = self.call_target_name(*callee, body) { - let throws = crate::throw_inference::function_throw_sets( - self.context.db(), - self.package_id, - ); - if let Some(transitive) = throws.transitive_for(&target) { - out.extend(transitive.iter().cloned()); - } else { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); - } - } else { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); - } - } - Expr::If { - condition, - then_branch, - else_branch, - } => { - self.collect_throw_facts_from_expr(*condition, body, out); - self.collect_throw_facts_from_expr(*then_branch, body, out); - if let Some(else_expr) = else_branch { - self.collect_throw_facts_from_expr(*else_expr, body, out); - } - } - Expr::Match { - scrutinee, arms, .. - } => { - self.collect_throw_facts_from_expr(*scrutinee, body, out); - for arm_id in arms { - let arm = &body.match_arms[*arm_id]; - if let Some(guard) = arm.guard { - self.collect_throw_facts_from_expr(guard, body, out); - } - self.collect_throw_facts_from_expr(arm.body, body, out); - } - } - Expr::Binary { lhs, rhs, .. } => { - self.collect_throw_facts_from_expr(*lhs, body, out); - self.collect_throw_facts_from_expr(*rhs, body, out); - } - Expr::Unary { expr, .. } => self.collect_throw_facts_from_expr(*expr, body, out), - Expr::Object { - fields, spreads, .. - } => { - for (_, value) in fields { - self.collect_throw_facts_from_expr(*value, body, out); - } - for spread in spreads { - self.collect_throw_facts_from_expr(spread.expr, body, out); - } - } - Expr::Array { elements } => { - for elem in elements { - self.collect_throw_facts_from_expr(*elem, body, out); - } - } - Expr::Map { entries } => { - for (key, value) in entries { - self.collect_throw_facts_from_expr(*key, body, out); - self.collect_throw_facts_from_expr(*value, body, out); - } - } - Expr::Block { stmts, tail_expr } => { - for stmt in stmts { - self.collect_throw_facts_from_stmt(*stmt, body, out); - } - if let Some(tail) = tail_expr { - self.collect_throw_facts_from_expr(*tail, body, out); - } - } - Expr::FieldAccess { base, .. } | Expr::OptionalFieldAccess { base, .. } => { - self.collect_throw_facts_from_expr(*base, body, out); - } - Expr::Index { base, index } | Expr::OptionalIndex { base, index } => { - self.collect_throw_facts_from_expr(*base, body, out); - self.collect_throw_facts_from_expr(*index, body, out); - } - Expr::OptionalCall { callee, args } => { - self.collect_throw_facts_from_expr(*callee, body, out); - for arg in args { - self.collect_throw_facts_from_expr(*arg, body, out); - } - let type_level_facts = self - .expressions - .get(callee) - .and_then(|ty| throws_semantics::function_throws_facts(ty, &self.aliases)) - .and_then(|facts| if facts.is_empty() { None } else { Some(facts) }); - if let Some(facts) = type_level_facts { - out.extend(facts); - } else if let Some(target) = self.call_target_name(*callee, body) { - let throws = crate::throw_inference::function_throw_sets( - self.context.db(), - self.package_id, - ); - if let Some(transitive) = throws.transitive_for(&target) { - out.extend(transitive.iter().cloned()); - } else { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); - } - } else { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); - } - } - Expr::Catch { base, .. } => { - self.collect_throw_facts_from_expr(*base, body, out); - } - Expr::OptionalChain { expr } => { - self.collect_throw_facts_from_expr(*expr, body, out); - } - Expr::Lambda(_) - | Expr::Literal(_) - | Expr::ByteStringLiteral(_) - | Expr::Null - | Expr::Path(_) - | Expr::Missing => {} - } - } - - fn collect_throw_facts_from_stmt( - &self, - stmt_id: StmtId, - body: &ExprBody, - out: &mut BTreeSet, - ) { - match &body.stmts[stmt_id] { - Stmt::Expr(expr_id) => self.collect_throw_facts_from_expr(*expr_id, body, out), - Stmt::Let { initializer, .. } => { - if let Some(init) = initializer { - self.collect_throw_facts_from_expr(*init, body, out); - } - } - Stmt::While { - condition, - body: while_body, - after, - .. - } => { - self.collect_throw_facts_from_expr(*condition, body, out); - self.collect_throw_facts_from_expr(*while_body, body, out); - if let Some(after_stmt) = after { - self.collect_throw_facts_from_stmt(*after_stmt, body, out); - } - } - Stmt::For { - collection, - body: for_body, - .. - } => { - self.collect_throw_facts_from_expr(*collection, body, out); - self.collect_throw_facts_from_expr(*for_body, body, out); - } - Stmt::Return(expr) => { - if let Some(expr) = expr { - self.collect_throw_facts_from_expr(*expr, body, out); - } - } - Stmt::Assign { target, value } | Stmt::AssignOp { target, value, .. } => { - self.collect_throw_facts_from_expr(*target, body, out); - self.collect_throw_facts_from_expr(*value, body, out); - } - Stmt::Throw { value } => { - self.collect_throw_facts_from_expr(*value, body, out); - self.collect_throw_facts_from_value(*value, body, out); - } - Stmt::Break | Stmt::Continue | Stmt::Missing | Stmt::HeaderComment { .. } => {} - } - } - - fn collect_throw_facts_from_value( - &self, - value_expr_id: ExprId, - body: &ExprBody, - out: &mut BTreeSet, - ) { - if let Some(thrown_ty) = self.expressions.get(&value_expr_id) { - out.extend(throws_semantics::flatten_ty_to_facts(thrown_ty)); - return; - } - - match &body.exprs[value_expr_id] { - Expr::Literal(lit) => out.extend(throws_semantics::flatten_ty_to_facts(&Ty::Literal( - lit.clone(), - Freshness::Regular, - TyAttr::default(), - ))), - Expr::ByteStringLiteral(_) => { - out.insert(Ty::Primitive(PrimitiveType::Uint8Array, TyAttr::default())); - } - Expr::Null => { - out.insert(Ty::Primitive(PrimitiveType::Null, TyAttr::default())); - } - _ => { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); - } - } - } - - fn call_target_name(&self, callee_expr_id: ExprId, body: &ExprBody) -> Option { - let segments = Self::expr_to_path_segments(callee_expr_id, body)?; - if segments.len() < 2 { - // Single-segment path (free function) — return as-is. - return if segments.is_empty() { - None - } else { - Some(segments[0].clone()) - }; - } - // Multi-segment: receiver.method — resolve the receiver's type to get - // the class name so the target matches throw_inference's "Class.method" keys. - let receiver = &segments[0]; - let method = &segments[1]; - if let Some(Ty::Class(qn, _)) = self.locals.get(receiver) { - let ns = qn.namespace(); - let key = if ns.is_empty() { - format!("{}.{}", qn.name(), method) - } else { - let ns_str = ns.iter().map(Name::as_str).collect::>().join("."); - format!("{}.{}.{}", ns_str, qn.name(), method) - }; - Some(Name::new(key)) - } else { - // Receiver not a known local or not a class — fall back to raw path - Some(Name::new( - segments - .iter() - .map(Name::as_str) - .collect::>() - .join("."), - )) - } - } - fn find_function_scope_id( &self, file: SourceFile, @@ -2835,101 +2578,51 @@ impl<'db> TypeInferenceBuilder<'db> { ) -> Ty { let db = self.context.db(); let mut diags = Vec::new(); - let mut synthetic_effect_vars: Vec = Vec::new(); + let boundary = lower_callable_boundary( + db, + pkg_items, + ns_context, + generic_params, + sig, + self_param_ty, + &mut diags, + ); - let params: Vec<(Option, Ty)> = sig - .params - .iter() - .map(|(n, te)| { - let param_ty = if n.as_str() == "self" - && matches!(te, baml_compiler2_ast::TypeExpr::Unknown { .. }) - { - self_param_ty.cloned().unwrap_or(Ty::Unknown { + let throws_ty = boundary.explicit_throws.clone().unwrap_or_else(|| { + if boundary.direct_callback_effect_vars.is_empty() { + match (function_key, body_package_id) { + (Some(function_key), Some(body_package_id)) => { + crate::throw_inference::function_throw_sets(db, body_package_id) + .transitive_for(function_key) + .cloned() + .map(throws_semantics::concrete_throws_ty_from_facts) + .unwrap_or(Ty::Never { + attr: TyAttr::default(), + }) + } + _ => Ty::Never { attr: TyAttr::default(), - }) - } else { - crate::lower_type_expr::lower_type_expr_with_fn_context( - db, - te, - pkg_items, - ns_context, - generic_params, - &mut diags, - &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { - param_name: n.clone(), - }, - &mut synthetic_effect_vars, - ) + }, + } + } else { + let body_throws_facts = match (body_scope, body_package_id, body) { + (Some(body_scope), Some(body_package_id), Some(body)) => { + self.infer_concrete_body_throws(body_scope, body_package_id, body) + } + _ => BTreeSet::new(), }; - (Some(n.clone()), param_ty) - }) - .collect(); - - let effective_generic_params: Vec = generic_params - .iter() - .cloned() - .chain(synthetic_effect_vars.iter().cloned()) - .collect(); - - let ret_ty = sig - .return_type - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items, - ns_context, - &effective_generic_params, - &mut diags, - ) - }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }); - - let throws_ty = sig - .throws - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items, - ns_context, - generic_params, - &mut diags, - ) - }) - .unwrap_or_else(|| { - if synthetic_effect_vars.is_empty() { - match (function_key, body_package_id) { - (Some(function_key), Some(body_package_id)) => { - crate::throw_inference::function_throw_sets(db, body_package_id) - .transitive_for(function_key) - .cloned() - .map(throws_semantics::concrete_throws_ty_from_facts) - .unwrap_or(Ty::Never { - attr: TyAttr::default(), - }) - } - _ => Ty::Never { - attr: TyAttr::default(), - }, + let used_effect_vars = match body { + Some(baml_compiler2_hir::body::FunctionBody::Expr(expr_body)) => { + boundary.used_direct_callback_effect_vars(expr_body) } - } else { - let body_throws_facts = match (body_scope, body_package_id, body) { - (Some(body_scope), Some(body_package_id), Some(body)) => { - self.infer_concrete_body_throws(body_scope, body_package_id, body) - } - _ => BTreeSet::new(), - }; - throws_semantics::combine_effect_vars_with_body_throws( - &synthetic_effect_vars, - body_throws_facts, - ) - } - }); + _ => Vec::new(), + }; + throws_semantics::combine_effect_vars_with_body_throws( + &used_effect_vars, + body_throws_facts, + ) + } + }); // This helper reconstructs function types during lookup/inference rather than // at the defining declaration site, so any lowering diagnostics must stay @@ -2937,25 +2630,13 @@ impl<'db> TypeInferenceBuilder<'db> { drop(diags); Ty::Function { - params, - ret: Box::new(ret_ty), + params: boundary.params, + ret: Box::new(boundary.ret), throws: Box::new(throws_ty), attr: TyAttr::default(), } } - fn expr_to_path_segments(expr_id: ExprId, body: &ExprBody) -> Option> { - match &body.exprs[expr_id] { - Expr::Path(segments) if !segments.is_empty() => Some(segments.clone()), - Expr::FieldAccess { base, field } => { - let mut base_segments = Self::expr_to_path_segments(*base, body)?; - base_segments.push(field.clone()); - Some(base_segments) - } - _ => None, - } - } - fn infer_literal(lit: &baml_base::Literal) -> Ty { Ty::Literal(lit.clone(), Freshness::Fresh, TyAttr::default()) } diff --git a/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs b/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs new file mode 100644 index 0000000000..612d710ff8 --- /dev/null +++ b/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs @@ -0,0 +1,154 @@ +use baml_base::Name; +use baml_compiler2_ast::{Expr, ExprBody, TypeExpr}; +use baml_compiler2_hir::{package::PackageItems, signature::FunctionSignature}; + +use crate::{ + infer_context::TirTypeError, + lower_type_expr::{ + FnTypeLoweringContext, lower_type_expr_in_ns, lower_type_expr_with_fn_context, + }, + ty::{Ty, TyAttr}, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct DirectCallbackEffectVar { + pub param_name: Name, + pub effect_var: Name, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct LoweredCallableBoundary { + pub params: Vec<(Option, Ty)>, + pub ret: Ty, + pub explicit_throws: Option, + pub direct_callback_effect_vars: Vec, +} + +impl LoweredCallableBoundary { + pub(crate) fn used_direct_callback_effect_vars(&self, body: &ExprBody) -> Vec { + let directly_invoked = directly_invoked_callback_params(body); + self.direct_callback_effect_vars + .iter() + .filter(|effect_var| directly_invoked.contains(&effect_var.param_name)) + .map(|effect_var| effect_var.effect_var.clone()) + .collect() + } +} + +/// Lower a named callable boundary (function/method signature) with the typed +/// rethrows rule applied exactly once. +/// +/// Only the outermost omitted `throws` on a direct function-typed parameter is +/// effect-polymorphic. Everything else remains closed-by-default. +pub(crate) fn lower_callable_boundary<'db>( + db: &'db dyn crate::Db, + package_items: &PackageItems<'db>, + ns_context: &[Name], + generic_params: &[Name], + sig: &FunctionSignature, + self_param_ty: Option<&Ty>, + diagnostics: &mut Vec, +) -> LoweredCallableBoundary { + let mut direct_callback_effect_vars = Vec::new(); + let mut all_synthetic_effect_vars = Vec::new(); + + let params: Vec<(Option, Ty)> = sig + .params + .iter() + .map(|(param_name, param_ty_expr)| { + let ty = if param_name.as_str() == "self" + && matches!(param_ty_expr, TypeExpr::Unknown { .. }) + { + self_param_ty.cloned().unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }) + } else { + let mut param_effect_vars = Vec::new(); + let lowered = lower_type_expr_with_fn_context( + db, + param_ty_expr, + package_items, + ns_context, + generic_params, + diagnostics, + &FnTypeLoweringContext::DirectParamRoot { + param_name: param_name.clone(), + }, + &mut param_effect_vars, + ); + if let Some(effect_var) = param_effect_vars.first() { + direct_callback_effect_vars.push(DirectCallbackEffectVar { + param_name: param_name.clone(), + effect_var: effect_var.clone(), + }); + } + all_synthetic_effect_vars.extend(param_effect_vars); + lowered + }; + (Some(param_name.clone()), ty) + }) + .collect(); + + let effective_generic_params: Vec = generic_params + .iter() + .cloned() + .chain(all_synthetic_effect_vars.iter().cloned()) + .collect(); + + let ret = sig + .return_type + .as_ref() + .map(|te| { + lower_type_expr_in_ns( + db, + te, + package_items, + ns_context, + &effective_generic_params, + diagnostics, + ) + }) + .unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }); + + let explicit_throws = sig.throws.as_ref().map(|te| { + lower_type_expr_in_ns( + db, + te, + package_items, + ns_context, + generic_params, + diagnostics, + ) + }); + + LoweredCallableBoundary { + params, + ret, + explicit_throws, + direct_callback_effect_vars, + } +} + +/// Syntactic scan for direct callback parameter invocation sites. +/// +/// A direct callback position is only considered open when the body directly +/// calls the parameter path itself: `f()` or `f?.()`. +pub(crate) fn directly_invoked_callback_params(body: &ExprBody) -> Vec { + let mut invoked = Vec::new(); + for (_, expr) in body.exprs.iter() { + let callee = match expr { + Expr::Call { callee, .. } | Expr::OptionalCall { callee, .. } => *callee, + _ => continue, + }; + + if let Expr::Path(segments) = &body.exprs[callee] + && segments.len() == 1 + && !invoked.contains(&segments[0]) + { + invoked.push(segments[0].clone()); + } + } + invoked +} diff --git a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs index 3e597fdb24..a89847665b 100644 --- a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs +++ b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs @@ -21,22 +21,49 @@ pub(crate) fn collect_effective_throws<'db>( aliases: &HashMap, include_typevars: bool, unknown_on_unresolved_call: bool, +) -> BTreeSet { + body.root_expr + .map(|root| { + collect_effective_throws_from_root_expr( + db, + package_id, + root, + body, + expressions, + catch_residual_throws, + aliases, + include_typevars, + unknown_on_unresolved_call, + ) + }) + .unwrap_or_default() +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn collect_effective_throws_from_root_expr<'db>( + db: &'db dyn crate::Db, + package_id: PackageId<'db>, + root_expr: ExprId, + body: &ExprBody, + expressions: &FxHashMap, + catch_residual_throws: &FxHashMap>, + aliases: &HashMap, + include_typevars: bool, + unknown_on_unresolved_call: bool, ) -> BTreeSet { let mut out = BTreeSet::new(); - if let Some(root) = body.root_expr { - collect_effective_throws_from_expr( - db, - package_id, - root, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - &mut out, - ); - } + collect_effective_throws_from_expr( + db, + package_id, + root_expr, + body, + expressions, + catch_residual_throws, + aliases, + include_typevars, + unknown_on_unresolved_call, + &mut out, + ); out } diff --git a/baml_language/crates/baml_compiler2_tir/src/infer_context.rs b/baml_language/crates/baml_compiler2_tir/src/infer_context.rs index 5de680aff2..b63ab64cb0 100644 --- a/baml_language/crates/baml_compiler2_tir/src/infer_context.rs +++ b/baml_language/crates/baml_compiler2_tir/src/infer_context.rs @@ -156,9 +156,12 @@ impl fmt::Display for TirTypeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { TirTypeError::TypeMismatch { expected, got } => { + let expected = crate::throws_semantics::format_ty_for_diagnostic(expected); + let got = crate::throws_semantics::format_ty_for_diagnostic(got); write!(f, "type mismatch: expected {expected}, got {got}") } TirTypeError::UnresolvedMember { base_type, member } => { + let base_type = crate::throws_semantics::format_ty_for_diagnostic(base_type); write!(f, "type `{base_type}` has no member `{member}`") } TirTypeError::UnresolvedName { name } => { @@ -179,21 +182,27 @@ impl fmt::Display for TirTypeError { ) } TirTypeError::NotCallable { ty } => { + let ty = crate::throws_semantics::format_ty_for_diagnostic(ty); write!(f, "`{ty}` is not a function — it cannot be called") } TirTypeError::NotIterable { ty } => { + let ty = crate::throws_semantics::format_ty_for_diagnostic(ty); write!(f, "cannot iterate over type `{ty}`") } TirTypeError::NotIndexable { ty } => { + let ty = crate::throws_semantics::format_ty_for_diagnostic(ty); write!(f, "type `{ty}` is not indexable") } TirTypeError::InvalidBinaryOp { op, lhs, rhs } => { + let lhs = crate::throws_semantics::format_ty_for_diagnostic(lhs); + let rhs = crate::throws_semantics::format_ty_for_diagnostic(rhs); write!( f, "operator `{op:?}` cannot be applied to `{lhs}` and `{rhs}`" ) } TirTypeError::InvalidUnaryOp { op, operand } => { + let operand = crate::throws_semantics::format_ty_for_diagnostic(operand); write!(f, "operator `{op:?}` cannot be applied to `{operand}`") } TirTypeError::UnresolvedType { name, suggestions } => { @@ -217,6 +226,7 @@ impl fmt::Display for TirTypeError { write!(f, "expected {expected} argument(s), got {got}") } TirTypeError::MissingReturn { expected } => { + let expected = crate::throws_semantics::format_ty_for_diagnostic(expected); write!(f, "missing return: expected `{expected}`") } TirTypeError::AliasCycle { name } => { @@ -229,6 +239,8 @@ impl fmt::Display for TirTypeError { scrutinee_type, missing_cases, } => { + let scrutinee_type = + crate::throws_semantics::format_ty_for_diagnostic(scrutinee_type); write!( f, "non-exhaustive match on `{scrutinee_type}`; missing: {}", @@ -243,11 +255,14 @@ impl fmt::Display for TirTypeError { TirTypeError::ThrowsContractViolation { declared, extra_types, - } => write!( - f, - "throws contract violation: `{declared}` is missing {}", - extra_types.join(", ") - ), + } => { + let declared = crate::throws_semantics::format_throws_ty_for_diagnostic(declared); + write!( + f, + "throws contract violation: `{declared}` is missing {}", + extra_types.join(", ") + ) + } TirTypeError::ExtraneousThrowsDeclaration { extra_types } => write!( f, "extraneous throws declaration: {}", @@ -319,10 +334,12 @@ impl fmt::Display for TirTypeError { ) } TirTypeError::StoredFunctionRequiresExplicitThrows { actual_throws } => { + let actual_throws = + crate::throws_semantics::format_throws_ty_for_diagnostic(actual_throws); write!( f, - "function that `throws {actual_throws}` cannot be stored in a position typed `throws never`; \ - add an explicit `throws {actual_throws}` annotation to the stored function type" + "function whose escaping throws are {actual_throws} cannot be stored in a position typed `throws never`; \ + add an explicit `throws` annotation to the stored function type" ) } } diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 83dd242313..9defde50fc 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -26,6 +26,7 @@ use text_size::TextRange; use crate::{ builder::TypeInferenceBuilder, + callable_boundary::lower_callable_boundary, infer_context::{InferContext, TypeCheckDiagnostics}, ty::{Ty, TyAttr}, }; @@ -338,24 +339,48 @@ pub fn infer_scope_types<'db>( builder.set_generic_params(generic_params.clone()); if let FunctionBody::Expr(expr_body) = body.as_ref() { - // Get declared return type - let mut diags = Vec::new(); - let return_ty = sig - .return_type - .as_ref() - .map(|te| { - crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items, - &pkg_info.namespace_path, - &generic_params, - &mut diags, + let enclosing_class_name: Option = + scope.parent.and_then(|parent_idx| { + let parent = &index.scopes[parent_idx.index() as usize]; + if matches!(parent.kind, ScopeKind::Class) { + parent.name.clone() + } else { + None + } + }); + let self_param_ty = enclosing_class_name.as_ref().and_then(|cn| { + let ns_path = &pkg_info.namespace_path; + pkg_items.lookup_type(ns_path, cn).map(|def| { + Ty::Class( + crate::lower_type_expr::qualify_def(db, def, cn), + TyAttr::default(), ) }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }); + }); + + let boundary = lower_callable_boundary( + db, + pkg_items, + &pkg_info.namespace_path, + &generic_params, + sig.as_ref(), + self_param_ty.as_ref(), + &mut Vec::new(), + ); + + // Get declared return type + let mut diags = Vec::new(); + if let Some(te) = sig.return_type.as_ref() { + let _ = crate::lower_type_expr::lower_type_expr_in_ns( + db, + te, + pkg_items, + &pkg_info.namespace_path, + &generic_params, + &mut diags, + ); + } + let return_ty = boundary.ret.clone(); // Report unresolved type diagnostics for return type if !diags.is_empty() { @@ -373,49 +398,29 @@ pub fn infer_scope_types<'db>( // Set declared return type for return statement checking builder.set_return_type(return_ty.clone()); - // Determine enclosing class name for `self` parameter resolution - let enclosing_class_name: Option = - scope.parent.and_then(|parent_idx| { - let parent = &index.scopes[parent_idx.index() as usize]; - if matches!(parent.kind, ScopeKind::Class) { - parent.name.clone() - } else { - None - } - }); - // Add parameter bindings as locals let sig_sm = baml_compiler2_hir::signature::function_signature_source_map( db, func_loc, ); - for (i, (param_name, param_te)) in sig.params.iter().enumerate() { - let param_ty = if param_name.as_str() == "self" - && matches!(param_te, baml_compiler2_ast::TypeExpr::Unknown { .. }) + for (i, ((param_name, param_te), (_, param_ty))) in + sig.params.iter().zip(boundary.params.iter()).enumerate() + { + if !(param_name.as_str() == "self" + && matches!(param_te, baml_compiler2_ast::TypeExpr::Unknown { .. })) { - // `self` parameter with no type annotation — infer from enclosing class - enclosing_class_name - .as_ref() - .and_then(|cn| { - let ns_path = &pkg_info.namespace_path; - pkg_items.lookup_type(ns_path, cn).map(|def| { - Ty::Class( - crate::lower_type_expr::qualify_def(db, def, cn), - TyAttr::default(), - ) - }) - }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }) - } else { let mut param_diags = Vec::new(); - let ty = crate::lower_type_expr::lower_type_expr_in_ns( + let mut param_effect_vars = Vec::new(); + let _ = crate::lower_type_expr::lower_type_expr_with_fn_context( db, param_te, pkg_items, &pkg_info.namespace_path, &generic_params, &mut param_diags, + &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { + param_name: param_name.clone(), + }, + &mut param_effect_vars, ); if !param_diags.is_empty() { let span = sig_sm @@ -429,9 +434,8 @@ pub fn infer_scope_types<'db>( builder.report_at_span(diag, span); } } - ty - }; - builder.add_local(param_name.clone(), param_ty); + } + builder.add_local(param_name.clone(), param_ty.clone()); } // Check root expression against declared return type diff --git a/baml_language/crates/baml_compiler2_tir/src/lib.rs b/baml_language/crates/baml_compiler2_tir/src/lib.rs index bc99ba0df8..a291556bb5 100644 --- a/baml_language/crates/baml_compiler2_tir/src/lib.rs +++ b/baml_language/crates/baml_compiler2_tir/src/lib.rs @@ -18,6 +18,7 @@ pub mod analysis; pub mod builder; +pub mod callable_boundary; pub mod cycle_detector; pub mod effective_throws; pub mod generics; diff --git a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs index 33f1135c7f..4d64faa995 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs @@ -15,6 +15,7 @@ use baml_compiler2_hir::{ }; use crate::{ + callable_boundary::{directly_invoked_callback_params, lower_callable_boundary}, inference::collect_type_aliases, lower_type_expr::{lower_type_expr_in_ns, qualify_def}, throws_semantics::function_throws_facts, @@ -520,64 +521,37 @@ fn collect_direct_param_call_throws<'db>( body: &ExprBody, aliases: &HashMap, ) -> BTreeSet { - let mut direct_param_throws: HashMap> = HashMap::new(); - - for (param_name, param_ty_expr) in &sig.params { - let mut diags = Vec::new(); - let lowered = lower_type_expr_in_ns( - db, - param_ty_expr, - pkg_items, - ns_context, - generic_params, - &mut diags, - ); - drop(diags); - - let Some(facts) = function_throws_facts(&lowered, aliases) else { - continue; - }; - - let concrete_facts: BTreeSet = facts - .into_iter() - .filter(|fact| !matches!(fact, Ty::TypeVar(_, _) | Ty::Never { .. } | Ty::Void { .. })) - .collect(); - - if !concrete_facts.is_empty() { - direct_param_throws.insert(param_name.clone(), concrete_facts); - } - } - - if direct_param_throws.is_empty() { + let boundary = lower_callable_boundary( + db, + pkg_items, + ns_context, + generic_params, + sig, + None, + &mut Vec::new(), + ); + let directly_invoked = directly_invoked_callback_params(body); + + if directly_invoked.is_empty() { return BTreeSet::new(); } let mut out = BTreeSet::new(); - for (_, expr) in body.exprs.iter() { - let callee = match expr { - Expr::Call { callee, .. } | Expr::OptionalCall { callee, .. } => *callee, - _ => continue, - }; - - let Some(param_name) = direct_param_callee_name(callee, body) else { + for ((param_name, _), (_, param_ty)) in sig.params.iter().zip(boundary.params.iter()) { + if !directly_invoked.contains(param_name) { continue; - }; - - if let Some(facts) = direct_param_throws.get(¶m_name) { - out.extend(facts.iter().cloned()); } + let Some(facts) = function_throws_facts(param_ty, aliases) else { + continue; + }; + out.extend(facts.into_iter().filter(|fact| { + !matches!(fact, Ty::TypeVar(_, _) | Ty::Never { .. } | Ty::Void { .. }) + })); } out } -fn direct_param_callee_name(expr_id: baml_compiler2_ast::ExprId, body: &ExprBody) -> Option { - match &body.exprs[expr_id] { - Expr::Path(segments) if segments.len() == 1 => Some(segments[0].clone()), - _ => None, - } -} - /// Look up a function's transitive throw set from dependency interfaces. fn lookup_dep_throw_set<'a>( dep_interfaces: &'a [(Name, &crate::package_interface::PackageInterface)], diff --git a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs index d9a195d3cf..fc0544af8e 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs @@ -232,6 +232,119 @@ pub(crate) fn function_shape_matches_ignoring_outer_throws( } } +pub(crate) fn format_throw_fact_for_diagnostic(ty: &Ty) -> String { + match ty { + Ty::TypeVar(name, _) if name.as_str().starts_with("__throws_") => { + format!( + "throws from callback parameter `{}`", + &name.as_str()["__throws_".len()..] + ) + } + _ => format_ty_for_diagnostic(ty), + } +} + +pub(crate) fn format_throws_ty_for_diagnostic(ty: &Ty) -> String { + match ty { + Ty::Union(members, _) => members + .iter() + .map(format_throw_fact_for_diagnostic) + .collect::>() + .join(" | "), + other => format_throw_fact_for_diagnostic(other), + } +} + +fn needs_postfix_parens_for_diagnostic(ty: &Ty) -> bool { + matches!(ty, Ty::Union(..) | Ty::Function { .. }) +} + +fn format_postfix_base_for_diagnostic(ty: &Ty) -> String { + let rendered = format_ty_for_diagnostic(ty); + if needs_postfix_parens_for_diagnostic(ty) { + format!("({rendered})") + } else { + rendered + } +} + +pub(crate) fn format_ty_for_diagnostic(ty: &Ty) -> String { + match ty { + Ty::Class(qn, _) | Ty::Enum(qn, _) | Ty::TypeAlias(qn, _) => qn.to_string(), + Ty::EnumVariant(qn, variant, _) => format!("{qn}.{variant}"), + Ty::Primitive(p, _) => p.to_string(), + Ty::List(inner, _) => format!("{}[]", format_postfix_base_for_diagnostic(inner)), + Ty::Map(key, value, _) => { + format!( + "map<{}, {}>", + format_ty_for_diagnostic(key), + format_ty_for_diagnostic(value) + ) + } + Ty::Union(members, _) => members + .iter() + .map(format_ty_for_diagnostic) + .collect::>() + .join(" | "), + Ty::Optional(inner, _) => format!("{}?", format_postfix_base_for_diagnostic(inner)), + Ty::Literal(lit, _, _) => lit.to_string(), + Ty::EvolvingList(inner, _) => { + if matches!(inner.as_ref(), Ty::Never { .. }) { + "_[]".to_string() + } else { + format!("{}[] (evolving)", format_postfix_base_for_diagnostic(inner)) + } + } + Ty::EvolvingMap(key, value, _) => { + if matches!(key.as_ref(), Ty::Never { .. }) + && matches!(value.as_ref(), Ty::Never { .. }) + { + "map<_, _>".to_string() + } else { + format!( + "map<{}, {}> (evolving)", + format_ty_for_diagnostic(key), + format_ty_for_diagnostic(value) + ) + } + } + Ty::Function { + params, + ret, + throws, + .. + } => { + let rendered_params = params + .iter() + .map(|(name, ty)| match name { + Some(name) => format!("{name}: {}", format_ty_for_diagnostic(ty)), + None => format_ty_for_diagnostic(ty), + }) + .collect::>() + .join(", "); + let mut rendered = format!("({rendered_params}) -> {}", format_ty_for_diagnostic(ret)); + if !matches!(throws.as_ref(), Ty::Never { .. }) { + rendered.push_str(" throws "); + rendered.push_str(&format_throws_ty_for_diagnostic(throws)); + } + rendered + } + Ty::TypeVar(name, _) if name.as_str().starts_with("__throws_") => { + format!( + "throws from callback parameter `{}`", + &name.as_str()["__throws_".len()..] + ) + } + Ty::TypeVar(name, _) => name.to_string(), + Ty::Never { .. } => "never".to_string(), + Ty::Void { .. } => "void".to_string(), + Ty::BuiltinUnknown { .. } | Ty::Unknown { .. } => "unknown".to_string(), + Ty::RustType { .. } => "$rust_type".to_string(), + Ty::Type { .. } => "type".to_string(), + Ty::Error { .. } => "!error".to_string(), + } +} + fn declared_covers_fact( declared: &Ty, fact: &Ty, diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/class_field_fn_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/class_field_fn_throws.baml index cc1489a0c0..326b879a46 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/class_field_fn_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/class_field_fn_throws.baml @@ -13,6 +13,18 @@ class MixedHandler { risky: () -> int throws string } +class MethodRunner { + value: T + + function apply(self, f: (T) -> int) -> int { + f(self.value) + } + + function apply_underdeclared(self, f: (T) -> int) -> int throws never { + f(self.value) + } +} + // Function returning a class with function fields function make_pure_handler() -> PureHandler { PureHandler { run: () -> null { null } } @@ -21,3 +33,8 @@ function make_pure_handler() -> PureHandler { function make_throwing_handler() -> ThrowingHandler { ThrowingHandler { run: () -> null { throw "error" } } } + +function test_method_runner() -> int { + let runner = MethodRunner { value: 1 } + runner.apply((x: int) -> int { x + 1 }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml index 8d1c32fc4a..ecf2564709 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml @@ -53,3 +53,26 @@ function test_apply_with_helper_pure() -> int { function test_apply_with_helper_throwing() -> int { apply_with_helper(() -> int { throw 42 }) } + +// Explicitly under-declaring a direct callback rethrow should fail contract checking. +function apply_underdeclared(f: () -> int) -> int throws never { + f() +} + +// Unused callback params should not widen the wrapper outward throws. +function ignore_callback(f: () -> int) -> int { + 42 +} + +function test_ignore_callback_throwing() -> int { + ignore_callback(() -> int { throw "unused" }) +} + +// Concrete body throws still escape even when the callback is never invoked. +function ignore_callback_but_throw(f: () -> int) -> int { + helper_with_body_throw() +} + +function test_ignore_callback_but_throw() -> int { + ignore_callback_but_throw(() -> int { throw 42 }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml b/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml index cb1260e7b9..f25cff29c3 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml @@ -87,3 +87,10 @@ function make_stored_closure_good() -> (() -> int) { function make_stored_closure_with_throws() -> (() -> int throws string) { return () -> int { throw "oops" } } + +// ERROR: direct callback params become effect polymorphic at the function boundary, +// but storing them still requires an explicit stored function throws annotation. +function test_store_direct_callback_param(f: () -> int) -> int { + let stored: () -> int = f + stored() +} diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____03_hir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____03_hir.snap index d1cda7ba82..2c2a5609f1 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____03_hir.snap @@ -58,13 +58,13 @@ function testing.register_test_set(self: ?, name: string, collector: (testing.Te { let full_name = if (self.prefix Eq "") { } name else { } self.prefix Add "/" Add name; let count = 0; let hash_prefix = full_name Add "#"; for ts in self.testsets { } if (ts.name Eq full_name Or ts.name.startsWith(hash_prefix)) { count Add= 1 } else { } null; let final_name = if (count Gt 0) { } full_name Add "#" Add baml.unstable.string(count Add 1) else { } full_name; let _ = self.testsets.push(testing.TestSetRegistration { name: final_name, collector: collector, runner: runner }) } null } function testing.run_test(body: () -> null, runner: testing.TestRunner?) -> testing.TestReport [expr] { - { let base_run = () -> TestReport { { let start = baml.sys.now_ms(); let result = { body() } RunReport { outcome: "pass", duration_ms: 0 } catch_all (e) { _ => { } RunReport { outcome: "fail", duration_ms: 0 } } } TestReport { outcome: result.outcome, runs: [RunReport { outcome: result.outcome, duration_ms: baml.sys.now_ms() Sub start }] } }; let effective_run = match (runner) { null => base_run, r: testing.TestRunner => r(base_run) } } effective_run() + { let base_run = () -> TestReport { { let start = baml.sys.now_ms(); let result = { body() } RunReport { outcome: "pass", duration_ms: 0 } catch_all (e) { _ => { } RunReport { outcome: "fail", duration_ms: 0 } } } TestReport { outcome: result.outcome, runs: [RunReport { outcome: result.outcome, duration_ms: baml.sys.now_ms() Sub start }] } } } match (runner) { null => base_run(), r: testing.TestRunner => r(base_run)() } } function testing.run_test(self: ?, name: string) -> testing.TestReport [expr] { { for t in self.collector.tests { } if (t.name Eq name) { return root.run_test(t.body, t.runner) }; for k in self.expansions.keys() { let sub = self.expansions[k] } if (name.startsWith(k Add "/")) { return sub.run_test(name) }; throw "Test not found: " Add name } } function testing.run_testset(run_children: () -> testing.TestSetReport, runner: testing.TestSetRunner?) -> testing.TestSetReport [expr] { - { let effective_run = match (runner) { null => run_children, r: testing.TestSetRunner => r(run_children) } } effective_run() + { } match (runner) { null => run_children(), r: testing.TestSetRunner => r(run_children)() } } function testing.serialize(self: ?) -> testing.SerializedTestDef[] [expr] { { let items: testing.SerializedTestDef[] = []; for t in self.collector.tests { } items.push(testing.SerializedTest { type: "test", name: t.name }); for ts in self.collector.testsets { let expanded = self.expansions.get(ts.name) } match (expanded) { sub: testing.TestRegistry => { } items.push(testing.SerializedTestSet { name: ts.name, items: sub.serialize(), loadingTimeMs: sub.loading_time_ms }), null => { } items.push(testing.SerializedTest { type: "lazyTestSet", name: ts.name }) }; return items } diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap index 01b9ab2047..fb390d3b0f 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap @@ -309,13 +309,13 @@ fn testing.TestCollector.new(prefix: string) -> testing.TestCollector { } } -fn testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> null { +fn testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { // Locals: let _0: null // _0 // return let _1: testing.TestCollector // self // param let _2: string // name // param let _3: () -> null // body // param - let _4: ((() -> testing.TestReport) -> () -> testing.TestReport)? // runner // param + let _4: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)? // runner // param let _5: string // full_name let _6: bool let _7: string @@ -469,13 +469,13 @@ fn testing.TestCollector.register_test(self: testing.TestCollector, name: string } } -fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> null { +fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { // Locals: let _0: null // _0 // return let _1: testing.TestCollector // self // param let _2: string // name // param let _3: (testing.TestCollector) -> null // collector // param - let _4: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)? // runner // param + let _4: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)? // runner // param let _5: string // full_name let _6: bool let _7: string @@ -629,54 +629,53 @@ fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: st } } -fn testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> testing.TestReport { +fn testing.run_test(body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { // Locals: let _0: testing.TestReport // _0 // return let _1: () -> null // body // param [captured] - let _2: ((() -> testing.TestReport) -> () -> testing.TestReport)? // runner // param + let _2: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)? // runner // param let _3: () -> testing.TestReport // base_run - let _4: () -> testing.TestReport // effective_run - let _5: bool + let _4: bool + let _5: () -> testing.TestReport let _6: bool - let _7: (() -> testing.TestReport) -> () -> testing.TestReport // r - let _8: (() -> testing.TestReport) -> () -> testing.TestReport - let _9: () -> testing.TestReport + let _7: (() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown // r + let _8: () -> testing.TestReport throws unknown + let _9: (() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown let _10: () -> testing.TestReport bb0: { _3 = make_closure lambda[0](copy _1); - _5 = copy _2 == const null; - branch copy _5 -> [bb5, bb1]; + _4 = copy _2 == const null; + branch copy _4 -> [bb6, bb1]; } bb1: { - _6 = is_type(copy _2, Function { params: [Function { params: [], ret: Class(TypeName { name: "TestReport", module_path: ["testing"], display_name: "testing.TestReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }], ret: Function { params: [], ret: Class(TypeName { name: "TestReport", module_path: ["testing"], display_name: "testing.TestReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }, throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + _6 = is_type(copy _2, Function { params: [Function { params: [], ret: Class(TypeName { name: "TestReport", module_path: ["testing"], display_name: "testing.TestReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: Some(BuiltinUnknown { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }), attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }], ret: Function { params: [], ret: Class(TypeName { name: "TestReport", module_path: ["testing"], display_name: "testing.TestReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: Some(BuiltinUnknown { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }), attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }, throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); branch copy _6 -> [bb3, bb2]; } bb2: { - goto -> bb6; + goto -> bb8; } bb3: { _7 = copy _2; - _8 = copy _7; - _9 = copy _3; - _4 = call copy _8(copy _9) -> [bb4]; + _9 = copy _7; + _10 = copy _3; + _8 = call copy _9(copy _10) -> [bb4]; } bb4: { - goto -> bb6; + _0 = call copy _8() -> [bb5]; } bb5: { - _4 = copy _3; - goto -> bb6; + goto -> bb8; } bb6: { - _10 = copy _4; - _0 = call copy _10() -> [bb7]; + _5 = copy _3; + _0 = call copy _5() -> [bb7]; } bb7: { @@ -684,6 +683,10 @@ fn testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) -> () } bb8: { + goto -> bb9; + } + + bb9: { return; } } @@ -773,7 +776,7 @@ fn testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> te let _12: testing.TestRegistration let _13: () -> null throws unknown let _14: testing.TestRegistration - let _15: ((() -> testing.TestReport) -> () -> testing.TestReport)? + let _15: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)? let _16: testing.TestRegistration let _17: string[] let _18: map @@ -901,50 +904,47 @@ fn testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> te } } -fn testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> testing.TestSetReport { +fn testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> testing.TestSetReport { // Locals: let _0: testing.TestSetReport // _0 // return let _1: () -> testing.TestSetReport // run_children // param - let _2: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)? // runner // param - let _3: () -> testing.TestSetReport // effective_run + let _2: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)? // runner // param + let _3: bool let _4: bool - let _5: bool - let _6: (() -> testing.TestSetReport) -> () -> testing.TestSetReport // r - let _7: (() -> testing.TestSetReport) -> () -> testing.TestSetReport - let _8: () -> testing.TestSetReport + let _5: (() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown // r + let _6: () -> testing.TestSetReport throws unknown + let _7: (() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown bb0: { - _4 = copy _2 == const null; - branch copy _4 -> [bb5, bb1]; + _3 = copy _2 == const null; + branch copy _3 -> [bb6, bb1]; } bb1: { - _5 = is_type(copy _2, Function { params: [Function { params: [], ret: Class(TypeName { name: "TestSetReport", module_path: ["testing"], display_name: "testing.TestSetReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }], ret: Function { params: [], ret: Class(TypeName { name: "TestSetReport", module_path: ["testing"], display_name: "testing.TestSetReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }, throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); - branch copy _5 -> [bb3, bb2]; + _4 = is_type(copy _2, Function { params: [Function { params: [], ret: Class(TypeName { name: "TestSetReport", module_path: ["testing"], display_name: "testing.TestSetReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: Some(BuiltinUnknown { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }), attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }], ret: Function { params: [], ret: Class(TypeName { name: "TestSetReport", module_path: ["testing"], display_name: "testing.TestSetReport" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] }), throws: Some(BuiltinUnknown { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }), attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }, throws: None, attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + branch copy _4 -> [bb3, bb2]; } bb2: { - goto -> bb6; + goto -> bb8; } bb3: { - _6 = copy _2; - _7 = copy _6; - _3 = call copy _7(copy _1) -> [bb4]; + _5 = copy _2; + _7 = copy _5; + _6 = call copy _7(copy _1) -> [bb4]; } bb4: { - goto -> bb6; + _0 = call copy _6() -> [bb5]; } bb5: { - _3 = copy _1; - goto -> bb6; + goto -> bb8; } bb6: { - _8 = copy _3; - _0 = call copy _8() -> [bb7]; + _0 = call copy _1() -> [bb7]; } bb7: { @@ -952,6 +952,10 @@ fn testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() - } bb8: { + goto -> bb9; + } + + bb9: { return; } } diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap index cfe42b2e21..cdfa0c4ace 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap @@ -206,7 +206,7 @@ function testing.TestRegistry.expand_set(self: testing.TestRegistry, name: strin throw "TestSet not found..." + name : string } } -function testing.run_test(body: () -> null throws __throws_body, runner: testing.TestRunner?) -> testing.TestReport throws __throws_body { +function testing.run_test(body: () -> null throws __throws_body, runner: testing.TestRunner?) -> testing.TestReport throws __throws_body | unknown { { : testing.TestReport let base_run = : () -> testing.TestReport () -> TestReport { ... } : () -> testing.TestReport @@ -216,26 +216,22 @@ function testing.run_test(body: () -> null throws __throws_body, runner: testing { 1 stmts + tail } catch_all (e) { _ => { 0 stmts + tail } } TestReport { outcome: result.outcome, runs: [RunReport { outcome: result.outcome, duration_ms: baml.sys.now_ms() - start }] } } - let effective_run = : () -> testing.TestReport - match (runner : testing.TestRunner?) : () -> testing.TestReport - null => - base_run : () -> testing.TestReport - r: TestRunner => - r(base_run) : () -> testing.TestReport - effective_run() : testing.TestReport + match (runner : testing.TestRunner?) : testing.TestReport + null => + base_run() : testing.TestReport + r: TestRunner => + r(base_run)() : testing.TestReport } } lambda testing.run_test { } -function testing.run_testset(run_children: () -> testing.TestSetReport throws __throws_run_children, runner: testing.TestSetRunner?) -> testing.TestSetReport throws __throws_run_children { +function testing.run_testset(run_children: () -> testing.TestSetReport throws __throws_run_children, runner: testing.TestSetRunner?) -> testing.TestSetReport throws __throws_run_children | unknown { { : testing.TestSetReport - let effective_run = : () -> testing.TestSetReport - match (runner : testing.TestSetRunner?) : () -> testing.TestSetReport - null => - run_children : () -> testing.TestSetReport - r: TestSetRunner => - r(run_children) : () -> testing.TestSetReport - effective_run() : testing.TestSetReport + match (runner : testing.TestSetRunner?) : testing.TestSetReport + null => + run_children() : testing.TestSetReport + r: TestSetRunner => + r(run_children)() : testing.TestSetReport } } class testing.TestRegistration$stream { @@ -287,8 +283,8 @@ class testing.TestSetReport { results: testing.ChildReport[] } type testing.ChildReport = testing.TestReport | testing.TestSetReport -type testing.TestRunner = (() -> testing.TestReport) -> () -> testing.TestReport -type testing.TestSetRunner = (() -> testing.TestSetReport) -> () -> testing.TestSetReport +type testing.TestRunner = (() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown +type testing.TestSetRunner = (() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown type testing.Outcome$stream = "pass" | "fail" | "error" class testing.RunReport$stream { outcome: null | testing.Outcome$stream diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap index d654398641..ff9b771e61 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap @@ -67,7 +67,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { load_var self load_field .prefix load_const "" @@ -191,7 +191,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () jump L3 } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { load_var self load_field .prefix load_const "" @@ -686,7 +686,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | jump L0 } -function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) -> () -> testing.TestReport)?) -> testing.TestReport { +function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { load_var ?1 make_cell store_var ?1 @@ -708,20 +708,18 @@ function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport) load_var base_run load_var runner call_indirect - store_var effective_run + call_indirect jump L2 L1: load_var base_run - store_var effective_run + call_indirect L2: - load_var effective_run - call_indirect return } -function testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport) -> () -> testing.TestSetReport)?) -> testing.TestSetReport { +function testing.run_testset(run_children: () -> testing.TestSetReport, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> testing.TestSetReport { load_var runner load_const null cmp_op == @@ -737,15 +735,13 @@ function testing.run_testset(run_children: () -> testing.TestSetReport, runner: load_var run_children load_var runner call_indirect - store_var effective_run + call_indirect jump L2 L1: load_var run_children - store_var effective_run + call_indirect L2: - load_var effective_run - call_indirect return } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__class_field_fn_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__class_field_fn_throws.snap index 02b0010e06..0c3e827ba0 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__class_field_fn_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__class_field_fn_throws.snap @@ -54,6 +54,64 @@ Word "int" Throws "throws" Word "string" RBrace "}" +Class "class" +Word "MethodRunner" +Less "<" +Word "T" +Greater ">" +LBrace "{" +Word "value" +Colon ":" +Word "T" +Function "function" +Word "apply" +LParen "(" +Word "self" +Comma "," +Word "f" +Colon ":" +LParen "(" +Word "T" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "f" +LParen "(" +Word "self" +Dot "." +Word "value" +RParen ")" +RBrace "}" +Function "function" +Word "apply_underdeclared" +LParen "(" +Word "self" +Comma "," +Word "f" +Colon ":" +LParen "(" +Word "T" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "never" +LBrace "{" +Word "f" +LParen "(" +Word "self" +Dot "." +Word "value" +RParen ")" +RBrace "}" +RBrace "}" Slash "/" Slash "/" Word "Function" @@ -106,3 +164,40 @@ Quote "\"" RBrace "}" RBrace "}" RBrace "}" +Function "function" +Word "test_method_runner" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "runner" +Equals "=" +Word "MethodRunner" +Less "<" +Word "int" +Greater ">" +LBrace "{" +Word "value" +Colon ":" +IntegerLiteral "1" +RBrace "}" +Word "runner" +Dot "." +Word "apply" +LParen "(" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "x" +Plus "+" +IntegerLiteral "1" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap index 326e0dc631..19c378b35f 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap @@ -306,3 +306,135 @@ IntegerLiteral "42" RBrace "}" RParen ")" RBrace "}" +Slash "/" +Slash "/" +Word "Explicitly" +Word "under-declaring" +Word "a" +Word "direct" +Word "callback" +Word "rethrow" +Word "should" +Word "fail" +Word "contract" +Word "checking" +Dot "." +Function "function" +Word "apply_underdeclared" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +Throws "throws" +Word "never" +LBrace "{" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Unused" +Word "callback" +Word "params" +Word "should" +Word "not" +Word "widen" +Word "the" +Word "wrapper" +Word "outward" +Throws "throws" +Dot "." +Function "function" +Word "ignore_callback" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +IntegerLiteral "42" +RBrace "}" +Function "function" +Word "test_ignore_callback_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "ignore_callback" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "unused" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Concrete" +Word "body" +Throws "throws" +Word "still" +Word "escape" +Word "even" +Word "when" +Word "the" +Word "callback" +Word "is" +Word "never" +Word "invoked" +Dot "." +Function "function" +Word "ignore_callback_but_throw" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "helper_with_body_throw" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_ignore_callback_but_throw" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "ignore_callback_but_throw" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap index bcada89efc..3f10e1fa86 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap @@ -546,3 +546,58 @@ Word "oops" Quote "\"" RBrace "}" RBrace "}" +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "direct" +Word "callback" +Word "params" +Word "become" +Word "effect" +Word "polymorphic" +Word "at" +Word "the" +Function "function" +Word "boundary" +Comma "," +Slash "/" +Slash "/" +Word "but" +Word "storing" +Word "them" +Word "still" +Word "requires" +Word "an" +Word "explicit" +Word "stored" +Function "function" +Throws "throws" +Word "annotation" +Dot "." +Function "function" +Word "test_store_direct_callback_param" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "stored" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +Equals "=" +Word "f" +Word "stored" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__class_field_fn_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__class_field_fn_throws.snap index cbeebc1074..c7cfc14120 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__class_field_fn_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__class_field_fn_throws.snap @@ -62,6 +62,99 @@ SOURCE_FILE TYPE_EXPR "string" WORD "string" R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "MethodRunner" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "T" + WORD "T" + GREATER ">" + L_BRACE "{" + FIELD + WORD "value" + COLON ":" + TYPE_EXPR "T" + WORD "T" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply" + PARAMETER_LIST + L_PAREN "(" + PARAMETER "self" + WORD "self" + COMMA "," + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "T" + WORD "T" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS + L_PAREN "(" + PATH_EXPR "self.value" + WORD "self" + DOT "." + WORD "value" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_underdeclared" + PARAMETER_LIST + L_PAREN "(" + PARAMETER "self" + WORD "self" + COMMA "," + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR "T" + WORD "T" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "never" + WORD "never" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS + L_PAREN "(" + PATH_EXPR "self.value" + WORD "self" + DOT "." + WORD "value" + R_PAREN ")" + R_BRACE "}" + R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" WORD "make_pure_handler" @@ -130,6 +223,64 @@ SOURCE_FILE R_BRACE "}" R_BRACE "}" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_method_runner" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "runner" + EQUALS "=" + OBJECT_LITERAL + PATH_EXPR + WORD "MethodRunner" + GENERIC_ARGS + LESS "<" + TYPE_EXPR "int" + WORD "int" + GREATER ">" + L_BRACE "{" + OBJECT_FIELD "value: 1" + WORD "value" + COLON ":" + INTEGER_LITERAL "1" + R_BRACE "}" + CALL_EXPR + PATH_EXPR "runner.apply" + WORD "runner" + DOT "." + WORD "apply" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + BINARY_EXPR "x + 1" + WORD "x" + PLUS "+" + INTEGER_LITERAL "1" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" === ERRORS === None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap index ecfdd7c114..d975f965ff 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap @@ -355,6 +355,156 @@ SOURCE_FILE R_BRACE "}" R_PAREN ")" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_underdeclared" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "never" + WORD "never" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "ignore_callback" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR "{ + 42 +}" + L_BRACE "{" + INTEGER_LITERAL "42" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_ignore_callback_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "ignore_callback" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "unused" + QUOTE """ + WORD "unused" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "ignore_callback_but_throw" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "helper_with_body_throw" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_ignore_callback_but_throw" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "ignore_callback_but_throw" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" === ERRORS === None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap index d2b0649ece..48df04aae6 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__stored_callback_enforcement.snap @@ -507,6 +507,45 @@ SOURCE_FILE QUOTE """ R_BRACE "}" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_store_direct_callback_param" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "stored" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EQUALS "=" + WORD "f" + CALL_EXPR + WORD "stored" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" === ERRORS === None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index 431c8e0b62..cf15232a4d 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -39,6 +39,9 @@ function user.test_chained_pure() -> int [expr] { function user.test_chained_throwing() -> int [expr] { { } apply_outer(() -> int { { throw "deep" } }) } +class user.MethodRunner { + value: user.T +} class user.MixedHandler { safe: () -> int risky: () -> int @@ -49,12 +52,21 @@ class user.PureHandler { class user.ThrowingHandler { run: () -> null } +function user.apply(self: ?, f: (user.T) -> int) -> int [expr] { + { } f(self.value) +} +function user.apply_underdeclared(self: ?, f: (user.T) -> int) -> int [expr] { + { } f(self.value) +} function user.make_pure_handler() -> user.PureHandler [expr] { { } user.PureHandler { run: () -> null { { } null } } } function user.make_throwing_handler() -> user.ThrowingHandler [expr] { { } user.ThrowingHandler { run: () -> null { { throw "error" } } } } +function user.test_method_runner() -> int [expr] { + { let runner = user.MethodRunner { value: 1 } } runner.apply((x: int) -> int { { } x Add 1 }) +} function user.compose(f: (user.A) -> user.B, g: (user.B) -> user.C) -> (user.A) -> user.C [expr] { { return (a: A) -> C { { } g(f(a)) } } } @@ -263,12 +275,21 @@ function user.apply_and_throw(f: () -> int) -> int [expr] { function user.apply_throwing(f: () -> int) -> int [expr] { { } f() } +function user.apply_underdeclared(f: () -> int) -> int [expr] { + { } f() +} function user.apply_with_helper(f: () -> int) -> int [expr] { { helper_with_body_throw() } f() } function user.helper_with_body_throw() -> int [expr] { { throw "helper boom" } } +function user.ignore_callback(f: () -> int) -> int [expr] { + { } 42 +} +function user.ignore_callback_but_throw(f: () -> int) -> int [expr] { + { } helper_with_body_throw() +} function user.test_apply_and_throw_pure() -> int [expr] { { } apply_and_throw(() -> int { { } 42 }) } @@ -287,6 +308,12 @@ function user.test_apply_with_helper_pure() -> int [expr] { function user.test_apply_with_helper_throwing() -> int [expr] { { } apply_with_helper(() -> int { { throw 42 } }) } +function user.test_ignore_callback_but_throw() -> int [expr] { + { } ignore_callback_but_throw(() -> int { { throw 42 } }) +} +function user.test_ignore_callback_throwing() -> int [expr] { + { } ignore_callback(() -> int { { throw "unused" } }) +} function user.test_explicit_throws_match() -> int [expr] { { let f = () -> int throws string { { throw "boom" } } } f() } @@ -390,6 +417,9 @@ function user.make_stored_closure_with_throws() -> () -> int [expr] { function user.make_stored_throwing_handler() -> user.StoredThrowingHandler [expr] { { } user.StoredThrowingHandler { run: () -> null { { throw "error" } } } } +function user.test_store_direct_callback_param(f: () -> int) -> int [expr] { + { let stored: () -> int = f } stored() +} function user.test_stored_alias_error() -> null [expr] { { let cb: user.StoredPureCb = () -> int { { throw "error" } } } null } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index de02e5de6e..c05f0a9734 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -363,6 +363,48 @@ fn .() -> null { } } +fn user.MethodRunner.apply(self: MethodRunner, f: (void) -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: MethodRunner // self // param + let _2: (void) -> int // f // param + let _3: void + + bb0: { + _3 = copy _1.0; + _0 = call copy _2(copy _3) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +fn user.MethodRunner.apply_underdeclared(self: MethodRunner, f: (void) -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: MethodRunner // self // param + let _2: (void) -> int // f // param + let _3: void + + bb0: { + _3 = copy _1.0; + _0 = call copy _2(copy _3) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + fn user.make_pure_handler() -> PureHandler { // Locals: let _0: PureHandler // _0 // return @@ -420,6 +462,45 @@ fn .() -> null { } } +fn user.test_method_runner() -> int { + // Locals: + let _0: int // _0 // return + let _1: MethodRunner // runner + let _2: MethodRunner + let _3: (int) -> int + + bb0: { + _1 = MethodRunner { const 1_i64 }; + _2 = copy _1; + _3 = make_closure lambda[0](); + _0 = call const fn user.MethodRunner.apply(copy _2, copy _3) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .(x: int) -> null { + // Locals: + let _0: null // _0 // return + let _1: int // x // param + + bb0: { + _0 = copy _1 + const 1_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + fn user.compose(f: (void) -> void, g: (void) -> void) -> (void) -> void { // Locals: let _0: (void) -> void // _0 // return @@ -2270,6 +2351,24 @@ fn user.apply_throwing(f: () -> int throws string) -> int { } } +fn user.apply_underdeclared(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call copy _1() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + fn user.apply_with_helper(f: () -> int) -> int { // Locals: let _0: int // _0 // return @@ -2302,6 +2401,39 @@ fn user.helper_with_body_throw() -> int { } } +fn user.ignore_callback(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = const 42_i64; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.ignore_callback_but_throw(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + + bb0: { + _0 = call const fn user.helper_with_body_throw() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + fn user.test_apply_and_throw_pure() -> int { // Locals: let _0: int // _0 // return @@ -2491,6 +2623,64 @@ fn .() -> null { } } +fn user.test_ignore_callback_but_throw() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void throws int + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.ignore_callback_but_throw(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + +fn user.test_ignore_callback_throwing() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void throws string + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.ignore_callback(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "unused"; + } +} + fn user.test_explicit_throws_match() -> int { // Locals: let _0: int // _0 // return @@ -3573,6 +3763,28 @@ fn .() -> null { } } +fn user.test_store_direct_callback_param(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + let _2: () -> int // stored + let _3: () -> int + + bb0: { + _2 = copy _1; + _3 = copy _2; + _0 = call copy _3() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + fn user.test_stored_alias_error() -> null { // Locals: let _0: null // _0 // return diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 31ec4435b9..2220b16135 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -96,7 +96,7 @@ function user.test_chained_pure() -> int throws never { } lambda user.test_chained_pure { } -function user.test_chained_throwing() -> int throws string { +function user.test_chained_throwing() -> int throws never { { : int apply_outer(() -> int { ... }) : int () -> int { ... } : () -> never throws string @@ -117,6 +117,20 @@ class user.MixedHandler { safe: () -> int risky: () -> int throws string } +class user.MethodRunner { + value: T +} +function user.MethodRunner.apply(self: user.MethodRunner, f: (unknown) -> int throws __throws_f) -> int throws __throws_f { + { : int + f(self.value) : int + } +} +function user.MethodRunner.apply_underdeclared(self: user.MethodRunner, f: (unknown) -> int throws __throws_f) -> int throws never { + { : int + f(self.value) : int + } + !! 402..408: throws contract violation: `never` is missing throws from callback parameter `f` +} function user.make_pure_handler() -> user.PureHandler throws never { { : user.PureHandler PureHandler { run: () -> null { ... } } : user.PureHandler @@ -131,6 +145,18 @@ function user.make_throwing_handler() -> user.ThrowingHandler throws never { } lambda user.make_throwing_handler { } +function user.test_method_runner() -> int throws never { + { : int + let runner = MethodRunner { value: 1 } : user.MethodRunner + runner.apply((x: int) -> int { ... }) : int + (x: int) -> int { ... } : (x: int) -> int + { + x + 1 + } + } +} +lambda user.test_method_runner { +} class user.PureHandler$stream { run: unknown } @@ -141,10 +167,13 @@ class user.MixedHandler$stream { safe: unknown risky: unknown } +class user.MethodRunner$stream { + value: null | unknown +} function user.compose(f: (A) -> B throws __throws_f, g: (B) -> C throws __throws_g) -> (A) -> C throws __throws_f | __throws_g { { : never - return : (a: A) -> C - (a: A) -> C { ... } : (a: A) -> C + return : (a: A) -> C throws __throws_f | __throws_g + (a: A) -> C { ... } : (a: A) -> C throws __throws_f | __throws_g { g(f(a)) } @@ -154,7 +183,7 @@ lambda user.compose { } function user.test_compose_pure() -> (int) -> string throws never { { : (int) -> string | "result" - compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | "result" + compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | "result" (x: int) -> int { ... } : (x: int) -> int { x * 2 @@ -169,9 +198,9 @@ lambda user.test_compose_pure { } lambda user.test_compose_pure { } -function user.test_compose_first_throws() -> (int) -> string throws string { +function user.test_compose_first_throws() -> (int) -> string throws never { { : (int) -> string | "result" - compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | "result" + compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | "result" (x: int) -> int { ... } : (x: int) -> never throws string { throw "f failed" @@ -186,9 +215,9 @@ lambda user.test_compose_first_throws { } lambda user.test_compose_first_throws { } -function user.test_compose_second_throws() -> (int) -> string throws string { +function user.test_compose_second_throws() -> (int) -> string throws never { { : (int) -> string | never - compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | never + compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | never (x: int) -> int { ... } : (x: int) -> int { x * 2 @@ -203,9 +232,9 @@ lambda user.test_compose_second_throws { } lambda user.test_compose_second_throws { } -function user.test_compose_both_throw() -> (int) -> string throws int | string { +function user.test_compose_both_throw() -> (int) -> string throws never { { : (int) -> string | never - compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | never + compose((x: int) -> int { ... }, (y: int) -> string { ... }) : (int) -> string | never (x: int) -> int { ... } : (x: int) -> never throws string { throw "string error" @@ -850,6 +879,44 @@ function user.test_apply_with_helper_throwing() -> int throws int | string { } lambda user.test_apply_with_helper_throwing { } +function user.apply_underdeclared(f: () -> int throws __throws_f) -> int throws never { + { : int + f() : int + } + !! 1490..1496: throws contract violation: `never` is missing throws from callback parameter `f` +} +function user.ignore_callback(f: () -> int throws __throws_f) -> int throws __throws_f { + { : 42 + 42 : 42 + } +} +function user.test_ignore_callback_throwing() -> int throws never { + { : int + ignore_callback(() -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "unused" + } + } +} +lambda user.test_ignore_callback_throwing { +} +function user.ignore_callback_but_throw(f: () -> int throws __throws_f) -> int throws __throws_f | string { + { : int + helper_with_body_throw() : int + } +} +function user.test_ignore_callback_but_throw() -> int throws string { + { : int + ignore_callback_but_throw(() -> int { ... }) : int + () -> int { ... } : () -> never throws int + { + throw 42 + } + } +} +lambda user.test_ignore_callback_but_throw { +} function user.test_explicit_throws_match() -> int throws string { { : never let f = : () -> never throws string @@ -1163,7 +1230,7 @@ function user.make_bad_stored_handler() -> user.StoredPureHandler throws never { { : user.StoredPureHandler StoredPureHandler { run: () -> null { ... } } : user.StoredPureHandler } - !! 378..407: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + !! 378..407: function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type } lambda user.make_bad_stored_handler { } @@ -1193,7 +1260,7 @@ function user.test_stored_local_error() -> null throws never { } null : null } - !! 1059..1087: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + !! 1059..1087: function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type } lambda user.test_stored_local_error { } @@ -1245,7 +1312,7 @@ function user.test_stored_alias_error() -> null throws never { } null : null } - !! 1875..1903: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + !! 1875..1903: function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type } lambda user.test_stored_alias_error { } @@ -1269,7 +1336,7 @@ function user.make_stored_closure_bad() -> () -> int throws never { throw "oops" } } - !! 2204..2231: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + !! 2204..2231: function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type } lambda user.make_stored_closure_bad { } @@ -1295,6 +1362,13 @@ function user.make_stored_closure_with_throws() -> () -> int throws string throw } lambda user.make_stored_closure_with_throws { } +function user.test_store_direct_callback_param(f: () -> int throws __throws_f) -> int throws __throws_f { + { : int + let stored = f : () -> int throws __throws_f + stored() : int + } + !! 2789..2790: function whose escaping throws are throws from callback parameter `f` cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type +} class user.StoredPureHandler$stream { run: unknown } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index 290499edb6..79ba8ab7ee 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -2,6 +2,26 @@ source: crates/baml_tests/src/generated_tests.rs --- === COMPILER2 DIAGNOSTICS === + [type] Error: throws contract violation: `never` is missing throws from callback parameter `f` + ╭─[ class_field_fn_throws.baml:23:66 ] + │ + 23 │ function apply_underdeclared(self, f: (T) -> int) -> int throws never { + │ ───┬── + │ ╰──── throws contract violation: `never` is missing throws from callback parameter `f` + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `never` is missing throws from callback parameter `f` + ╭─[ hof_throws.baml:58:57 ] + │ + 58 │ function apply_underdeclared(f: () -> int) -> int throws never { + │ ───┬── + │ ╰──── throws contract violation: `never` is missing throws from callback parameter `f` + │ + │ Note: Error code: E0001 +────╯ + [type] Warning: extraneous throws declaration: string ╭─[ lambda_throws_explicit.baml:17:27 ] │ @@ -72,22 +92,22 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ - [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + [type] Error: function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type ╭─[ stored_callback_enforcement.baml:12:27 ] │ 12 │ StoredPureHandler { run: () -> null { throw "error" } } │ ──────────────┬────────────── - │ ╰──────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + │ ╰──────────────── function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type │ │ Note: Error code: E0001 ────╯ - [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + [type] Error: function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type ╭─[ stored_callback_enforcement.baml:35:22 ] │ 35 │ let f: () -> null = () -> null { throw "oops" } │ ──────────────┬───────────── - │ ╰─────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + │ ╰─────────────── function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type │ │ Note: Error code: E0001 ────╯ @@ -112,22 +132,32 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ - [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + [type] Error: function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type ╭─[ stored_callback_enforcement.baml:64:25 ] │ 64 │ let cb: StoredPureCb = () -> int { throw "error" } │ ──────────────┬───────────── - │ ╰─────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + │ ╰─────────────── function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type │ │ Note: Error code: E0001 ────╯ - [type] Error: function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + [type] Error: function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type ╭─[ stored_callback_enforcement.baml:78:9 ] │ 78 │ return () -> int { throw "oops" } │ ─────────────┬───────────── - │ ╰─────────────── function that `throws string` cannot be stored in a position typed `throws never`; add an explicit `throws string` annotation to the stored function type + │ ╰─────────────── function whose escaping throws are string cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: function whose escaping throws are throws from callback parameter `f` cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type + ╭─[ stored_callback_enforcement.baml:94:27 ] + │ + 94 │ let stored: () -> int = f + │ ┬ + │ ╰── function whose escaping throws are throws from callback parameter `f` cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type │ │ Note: Error code: E0001 ────╯ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 9d005516e5..1f96421dfa 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -1,6 +1,22 @@ --- source: crates/baml_tests/src/generated_tests.rs --- +function user.MethodRunner.apply(self: null, f: (void) -> int) -> int { + load_var self + load_field .value + load_var f + call_indirect + return +} + +function user.MethodRunner.apply_underdeclared(self: null, f: (void) -> int) -> int { + load_var self + load_field .value + load_var f + call_indirect + return +} + function user.always_ok(x: int) -> string { load_const "always ok" return @@ -104,6 +120,12 @@ function user.apply_throwing(f: () -> int throws string) -> int { return } +function user.apply_underdeclared(f: () -> int) -> int { + load_var f + call_indirect + return +} + function user.apply_with_arg(x: int, f: (int) -> int) -> int { load_var x load_var f @@ -278,6 +300,16 @@ function user.helper_with_body_throw() -> int { throw } +function user.ignore_callback(f: () -> int) -> int { + load_const 42 + return +} + +function user.ignore_callback_but_throw(f: () -> int) -> int { + call user.helper_with_body_throw + return +} + function user.make_bad_stored_handler() -> StoredPureHandler { alloc_instance StoredPureHandler copy 0 @@ -655,6 +687,18 @@ function user.test_guarded_throwing() -> int { return } +function user.test_ignore_callback_but_throw() -> int { + make_closure ., 0 + call user.ignore_callback_but_throw + return +} + +function user.test_ignore_callback_throwing() -> int { + make_closure ., 0 + call user.ignore_callback + return +} + function user.test_lambda_throws_class() -> int { make_closure ., 0 call_indirect @@ -689,6 +733,16 @@ function user.test_map_it_throwing() -> string { return } +function user.test_method_runner() -> int { + alloc_instance MethodRunner + copy 0 + load_const 1 + store_field .value + make_closure ., 0 + call user.MethodRunner.apply + return +} + function user.test_mixed_pure() -> int { load_const 5 make_closure ., 0 @@ -776,6 +830,12 @@ function user.test_run_throwing() -> int { return } +function user.test_store_direct_callback_param(f: () -> int) -> int { + load_var f + call_indirect + return +} + function user.test_stored_alias_error() -> null { load_const null return From 8845cd3ddde86699113292702143e7ac23bd3647 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 00:35:21 -0500 Subject: [PATCH 13/26] Complete named callable throws boundary cleanup --- .../crates/baml_compiler2_tir/src/builder.rs | 28 +- .../src/callable_boundary.rs | 89 +- .../src/effective_throws.rs | 823 ++++-------------- .../baml_compiler2_tir/src/inference.rs | 102 +-- .../baml_compiler2_tir/src/throw_inference.rs | 10 +- .../src/throws_semantics.rs | 3 + .../crates/baml_compiler2_tir/src/ty.rs | 3 + .../nested_function_throws_validation.baml | 28 + .../function_type_throws/hof_throws.baml | 11 + ...ion_type_throws__01_lexer__hof_throws.snap | 70 ++ ...on_type_throws__02_parser__hof_throws.snap | 68 ++ ...l_tests__function_type_throws__03_hir.snap | 7 + ...tests__function_type_throws__04_5_mir.snap | 52 ++ ...l_tests__function_type_throws__04_tir.snap | 18 + ...sts__function_type_throws__06_codegen.snap | 13 + 15 files changed, 567 insertions(+), 758 deletions(-) create mode 100644 baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/nested_function_throws_validation.baml diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 827b1d8c0e..da245af2c1 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -1808,8 +1808,24 @@ impl<'db> TypeInferenceBuilder<'db> { self.context.report_at_span(diag, span); } + self.check_lowered_throws_contract(body, Some(&declared_ty), throws_span, fallback_span); + } + + /// Validate an already-lowered declared `throws` type against effective escaping throws. + pub fn check_lowered_throws_contract( + &mut self, + body: &ExprBody, + declared_throws: Option<&Ty>, + throws_span: Option, + fallback_span: TextRange, + ) { + let Some(declared_ty) = declared_throws else { + return; + }; + let effective = self.collect_effective_throws(body); - self.report_throws_contract_diff_at_span(&declared_ty, &effective, span); + let span = throws_span.unwrap_or(fallback_span); + self.report_throws_contract_diff_at_span(declared_ty, &effective, span); } fn lower_lambda_throws_annotation_silently( @@ -2577,7 +2593,6 @@ impl<'db> TypeInferenceBuilder<'db> { self_param_ty: Option<&Ty>, ) -> Ty { let db = self.context.db(); - let mut diags = Vec::new(); let boundary = lower_callable_boundary( db, pkg_items, @@ -2585,7 +2600,6 @@ impl<'db> TypeInferenceBuilder<'db> { generic_params, sig, self_param_ty, - &mut diags, ); let throws_ty = boundary.explicit_throws.clone().unwrap_or_else(|| { @@ -2624,11 +2638,9 @@ impl<'db> TypeInferenceBuilder<'db> { } }); - // This helper reconstructs function types during lookup/inference rather than - // at the defining declaration site, so any lowering diagnostics must stay - // attached to the original signature instead of being re-emitted at each use. - drop(diags); - + // This helper reconstructs function types away from the declaration + // site, so the boundary's lowering diagnostics stay attached to the + // original signature instead of being re-emitted at each lookup. Ty::Function { params: boundary.params, ret: Box::new(boundary.ret), diff --git a/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs b/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs index 612d710ff8..1a5e9a153c 100644 --- a/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs +++ b/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs @@ -1,6 +1,7 @@ use baml_base::Name; use baml_compiler2_ast::{Expr, ExprBody, TypeExpr}; use baml_compiler2_hir::{package::PackageItems, signature::FunctionSignature}; +use rustc_hash::FxHashSet; use crate::{ infer_context::TirTypeError, @@ -22,6 +23,9 @@ pub(crate) struct LoweredCallableBoundary { pub ret: Ty, pub explicit_throws: Option, pub direct_callback_effect_vars: Vec, + pub param_diagnostics: Vec>, + pub ret_diagnostics: Vec, + pub throws_diagnostics: Vec, } impl LoweredCallableBoundary { @@ -47,45 +51,48 @@ pub(crate) fn lower_callable_boundary<'db>( generic_params: &[Name], sig: &FunctionSignature, self_param_ty: Option<&Ty>, - diagnostics: &mut Vec, ) -> LoweredCallableBoundary { let mut direct_callback_effect_vars = Vec::new(); let mut all_synthetic_effect_vars = Vec::new(); + let mut param_diagnostics = Vec::with_capacity(sig.params.len()); let params: Vec<(Option, Ty)> = sig .params .iter() .map(|(param_name, param_ty_expr)| { - let ty = if param_name.as_str() == "self" - && matches!(param_ty_expr, TypeExpr::Unknown { .. }) - { - self_param_ty.cloned().unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }) - } else { - let mut param_effect_vars = Vec::new(); - let lowered = lower_type_expr_with_fn_context( - db, - param_ty_expr, - package_items, - ns_context, - generic_params, - diagnostics, - &FnTypeLoweringContext::DirectParamRoot { - param_name: param_name.clone(), - }, - &mut param_effect_vars, + if param_name.as_str() == "self" && matches!(param_ty_expr, TypeExpr::Unknown { .. }) { + param_diagnostics.push(Vec::new()); + return ( + Some(param_name.clone()), + self_param_ty.cloned().unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }), ); - if let Some(effect_var) = param_effect_vars.first() { - direct_callback_effect_vars.push(DirectCallbackEffectVar { - param_name: param_name.clone(), - effect_var: effect_var.clone(), - }); - } - all_synthetic_effect_vars.extend(param_effect_vars); - lowered - }; - (Some(param_name.clone()), ty) + } + + let mut slot_diagnostics = Vec::new(); + let mut param_effect_vars = Vec::new(); + let lowered = lower_type_expr_with_fn_context( + db, + param_ty_expr, + package_items, + ns_context, + generic_params, + &mut slot_diagnostics, + &FnTypeLoweringContext::DirectParamRoot { + param_name: param_name.clone(), + }, + &mut param_effect_vars, + ); + if let Some(effect_var) = param_effect_vars.first() { + direct_callback_effect_vars.push(DirectCallbackEffectVar { + param_name: param_name.clone(), + effect_var: effect_var.clone(), + }); + } + all_synthetic_effect_vars.extend(param_effect_vars); + param_diagnostics.push(slot_diagnostics); + (Some(param_name.clone()), lowered) }) .collect(); @@ -95,6 +102,7 @@ pub(crate) fn lower_callable_boundary<'db>( .chain(all_synthetic_effect_vars.iter().cloned()) .collect(); + let mut ret_diagnostics = Vec::new(); let ret = sig .return_type .as_ref() @@ -105,13 +113,14 @@ pub(crate) fn lower_callable_boundary<'db>( package_items, ns_context, &effective_generic_params, - diagnostics, + &mut ret_diagnostics, ) }) .unwrap_or(Ty::Unknown { attr: TyAttr::default(), }); + let mut throws_diagnostics = Vec::new(); let explicit_throws = sig.throws.as_ref().map(|te| { lower_type_expr_in_ns( db, @@ -119,7 +128,7 @@ pub(crate) fn lower_callable_boundary<'db>( package_items, ns_context, generic_params, - diagnostics, + &mut throws_diagnostics, ) }); @@ -128,15 +137,24 @@ pub(crate) fn lower_callable_boundary<'db>( ret, explicit_throws, direct_callback_effect_vars, + param_diagnostics, + ret_diagnostics, + throws_diagnostics, } } /// Syntactic scan for direct callback parameter invocation sites. /// +/// This is a conservative syntactic check, not a flow-sensitive analysis. /// A direct callback position is only considered open when the body directly /// calls the parameter path itself: `f()` or `f?.()`. -pub(crate) fn directly_invoked_callback_params(body: &ExprBody) -> Vec { - let mut invoked = Vec::new(); +/// +/// Aliased callback calls like `let g = f; g()` are intentionally not tracked, +/// so the wrapper's outward throws will not include the callback's throws in +/// that case. A future extension could follow simple single-assignment `let` +/// aliases. +pub(crate) fn directly_invoked_callback_params(body: &ExprBody) -> FxHashSet { + let mut invoked = FxHashSet::default(); for (_, expr) in body.exprs.iter() { let callee = match expr { Expr::Call { callee, .. } | Expr::OptionalCall { callee, .. } => *callee, @@ -145,9 +163,8 @@ pub(crate) fn directly_invoked_callback_params(body: &ExprBody) -> Vec { if let Expr::Path(segments) = &body.exprs[callee] && segments.len() == 1 - && !invoked.contains(&segments[0]) { - invoked.push(segments[0].clone()); + invoked.insert(segments[0].clone()); } } invoked diff --git a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs index a89847665b..18602215f4 100644 --- a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs +++ b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs @@ -22,20 +22,18 @@ pub(crate) fn collect_effective_throws<'db>( include_typevars: bool, unknown_on_unresolved_call: bool, ) -> BTreeSet { + let context = EffectiveThrowsContext { + db, + package_id, + body, + expressions, + catch_residual_throws, + aliases, + include_typevars, + unknown_on_unresolved_call, + }; body.root_expr - .map(|root| { - collect_effective_throws_from_root_expr( - db, - package_id, - root, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - ) - }) + .map(|root| context.collect_effective_throws_from_root_expr(root)) .unwrap_or_default() } @@ -51,679 +49,216 @@ pub(crate) fn collect_effective_throws_from_root_expr<'db>( include_typevars: bool, unknown_on_unresolved_call: bool, ) -> BTreeSet { - let mut out = BTreeSet::new(); - collect_effective_throws_from_expr( + EffectiveThrowsContext { db, package_id, - root_expr, body, expressions, catch_residual_throws, aliases, include_typevars, unknown_on_unresolved_call, - &mut out, - ); - out -} - -#[derive(Clone, Copy)] -struct CallResolutionOptions { - include_typevars: bool, - unknown_on_unresolved_call: bool, + } + .collect_effective_throws_from_root_expr(root_expr) } -#[allow(clippy::too_many_arguments)] -fn collect_effective_throws_from_expr<'db>( +struct EffectiveThrowsContext<'a, 'db> { db: &'db dyn crate::Db, package_id: PackageId<'db>, - expr_id: ExprId, - body: &ExprBody, - expressions: &FxHashMap, - catch_residual_throws: &FxHashMap>, - aliases: &HashMap, + body: &'a ExprBody, + expressions: &'a FxHashMap, + catch_residual_throws: &'a FxHashMap>, + aliases: &'a HashMap, include_typevars: bool, unknown_on_unresolved_call: bool, - out: &mut BTreeSet, -) { - match &body.exprs[expr_id] { - Expr::Throw { value } => { - collect_effective_throws_from_expr( - db, - package_id, - *value, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_throw_facts_from_value(*value, body, expressions, out); - } - Expr::Call { callee, args } => { - collect_effective_throws_from_expr( - db, - package_id, - *callee, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - for arg in args { - collect_effective_throws_from_expr( - db, - package_id, - *arg, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); +} + +impl EffectiveThrowsContext<'_, '_> { + fn collect_effective_throws_from_root_expr(&self, root_expr: ExprId) -> BTreeSet { + let mut out = BTreeSet::new(); + self.collect_effective_throws_from_expr(root_expr, &mut out); + out + } + + fn collect_effective_throws_from_expr(&self, expr_id: ExprId, out: &mut BTreeSet) { + match &self.body.exprs[expr_id] { + Expr::Throw { value } => { + self.collect_effective_throws_from_expr(*value, out); + collect_throw_facts_from_value(*value, self.body, self.expressions, out); } - collect_effective_throws_from_call( - db, - package_id, - *callee, - body, - expressions, - aliases, - CallResolutionOptions { - include_typevars, - unknown_on_unresolved_call, - }, - out, - ); - } - Expr::Catch { clauses, .. } => { - if let Some(residual) = catch_residual_throws.get(&expr_id) { - out.extend(residual.iter().cloned()); + Expr::Call { callee, args } | Expr::OptionalCall { callee, args } => { + self.collect_effective_throws_from_expr(*callee, out); + for arg in args { + self.collect_effective_throws_from_expr(*arg, out); + } + self.collect_effective_throws_from_call(*callee, out); } - for clause in clauses { - for arm_id in &clause.arms { - let arm = &body.catch_arms[*arm_id]; - collect_effective_throws_from_expr( - db, - package_id, - arm.body, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Catch { clauses, .. } => { + if let Some(residual) = self.catch_residual_throws.get(&expr_id) { + out.extend(residual.iter().cloned()); + } + for clause in clauses { + for arm_id in &clause.arms { + let arm = &self.body.catch_arms[*arm_id]; + self.collect_effective_throws_from_expr(arm.body, out); + } } } - } - Expr::If { - condition, - then_branch, - else_branch, - } => { - collect_effective_throws_from_expr( - db, - package_id, - *condition, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_effective_throws_from_expr( - db, - package_id, - *then_branch, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - if let Some(else_expr) = else_branch { - collect_effective_throws_from_expr( - db, - package_id, - *else_expr, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::If { + condition, + then_branch, + else_branch, + } => { + self.collect_effective_throws_from_expr(*condition, out); + self.collect_effective_throws_from_expr(*then_branch, out); + if let Some(else_expr) = else_branch { + self.collect_effective_throws_from_expr(*else_expr, out); + } } - } - Expr::Match { - scrutinee, arms, .. - } => { - collect_effective_throws_from_expr( - db, - package_id, - *scrutinee, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - for arm_id in arms { - let arm = &body.match_arms[*arm_id]; - if let Some(guard) = arm.guard { - collect_effective_throws_from_expr( - db, - package_id, - guard, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Match { + scrutinee, arms, .. + } => { + self.collect_effective_throws_from_expr(*scrutinee, out); + for arm_id in arms { + let arm = &self.body.match_arms[*arm_id]; + if let Some(guard) = arm.guard { + self.collect_effective_throws_from_expr(guard, out); + } + self.collect_effective_throws_from_expr(arm.body, out); } - collect_effective_throws_from_expr( - db, - package_id, - arm.body, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); } - } - Expr::Binary { lhs, rhs, .. } => { - collect_effective_throws_from_expr( - db, - package_id, - *lhs, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_effective_throws_from_expr( - db, - package_id, - *rhs, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - } - Expr::Unary { expr, .. } => { - collect_effective_throws_from_expr( - db, - package_id, - *expr, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - } - Expr::Object { - fields, spreads, .. - } => { - for (_, value) in fields { - collect_effective_throws_from_expr( - db, - package_id, - *value, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Binary { lhs, rhs, .. } => { + self.collect_effective_throws_from_expr(*lhs, out); + self.collect_effective_throws_from_expr(*rhs, out); } - for spread in spreads { - collect_effective_throws_from_expr( - db, - package_id, - spread.expr, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Unary { expr, .. } | Expr::OptionalChain { expr } => { + self.collect_effective_throws_from_expr(*expr, out); } - } - Expr::Array { elements } => { - for elem in elements { - collect_effective_throws_from_expr( - db, - package_id, - *elem, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Object { + fields, spreads, .. + } => { + for (_, value) in fields { + self.collect_effective_throws_from_expr(*value, out); + } + for spread in spreads { + self.collect_effective_throws_from_expr(spread.expr, out); + } } - } - Expr::Map { entries } => { - for (key, value) in entries { - collect_effective_throws_from_expr( - db, - package_id, - *key, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_effective_throws_from_expr( - db, - package_id, - *value, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Array { elements } => { + for elem in elements { + self.collect_effective_throws_from_expr(*elem, out); + } } - } - Expr::Block { stmts, tail_expr } => { - for stmt_id in stmts { - collect_effective_throws_from_stmt( - db, - package_id, - *stmt_id, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Map { entries } => { + for (key, value) in entries { + self.collect_effective_throws_from_expr(*key, out); + self.collect_effective_throws_from_expr(*value, out); + } } - if let Some(tail) = tail_expr { - collect_effective_throws_from_expr( - db, - package_id, - *tail, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Block { stmts, tail_expr } => { + for stmt_id in stmts { + self.collect_effective_throws_from_stmt(*stmt_id, out); + } + if let Some(tail) = tail_expr { + self.collect_effective_throws_from_expr(*tail, out); + } } - } - Expr::FieldAccess { base, .. } | Expr::OptionalFieldAccess { base, .. } => { - collect_effective_throws_from_expr( - db, - package_id, - *base, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - } - Expr::Index { base, index } | Expr::OptionalIndex { base, index } => { - collect_effective_throws_from_expr( - db, - package_id, - *base, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_effective_throws_from_expr( - db, - package_id, - *index, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - } - Expr::OptionalCall { callee, args } => { - collect_effective_throws_from_expr( - db, - package_id, - *callee, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - for arg in args { - collect_effective_throws_from_expr( - db, - package_id, - *arg, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::FieldAccess { base, .. } | Expr::OptionalFieldAccess { base, .. } => { + self.collect_effective_throws_from_expr(*base, out); } - collect_effective_throws_from_call( - db, - package_id, - *callee, - body, - expressions, - aliases, - CallResolutionOptions { - include_typevars, - unknown_on_unresolved_call, - }, - out, - ); - } - Expr::OptionalChain { expr } => { - collect_effective_throws_from_expr( - db, - package_id, - *expr, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Expr::Index { base, index } | Expr::OptionalIndex { base, index } => { + self.collect_effective_throws_from_expr(*base, out); + self.collect_effective_throws_from_expr(*index, out); + } + Expr::Lambda(_) + | Expr::Literal(_) + | Expr::ByteStringLiteral(_) + | Expr::Null + | Expr::Path(_) + | Expr::Missing => {} } - Expr::Lambda(_) - | Expr::Literal(_) - | Expr::ByteStringLiteral(_) - | Expr::Null - | Expr::Path(_) - | Expr::Missing => {} } -} -#[allow(clippy::too_many_arguments)] -fn collect_effective_throws_from_stmt<'db>( - db: &'db dyn crate::Db, - package_id: PackageId<'db>, - stmt_id: StmtId, - body: &ExprBody, - expressions: &FxHashMap, - catch_residual_throws: &FxHashMap>, - aliases: &HashMap, - include_typevars: bool, - unknown_on_unresolved_call: bool, - out: &mut BTreeSet, -) { - match &body.stmts[stmt_id] { - Stmt::Expr(expr) => collect_effective_throws_from_expr( - db, - package_id, - *expr, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ), - Stmt::Let { initializer, .. } => { - if let Some(init) = initializer { - collect_effective_throws_from_expr( - db, - package_id, - *init, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + fn collect_effective_throws_from_stmt(&self, stmt_id: StmtId, out: &mut BTreeSet) { + match &self.body.stmts[stmt_id] { + Stmt::Expr(expr) => self.collect_effective_throws_from_expr(*expr, out), + Stmt::Let { initializer, .. } => { + if let Some(init) = initializer { + self.collect_effective_throws_from_expr(*init, out); + } } - } - Stmt::While { - condition, - body: while_body, - after, - .. - } => { - collect_effective_throws_from_expr( - db, - package_id, - *condition, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_effective_throws_from_expr( - db, - package_id, - *while_body, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - if let Some(after_stmt) = after { - collect_effective_throws_from_stmt( - db, - package_id, - *after_stmt, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Stmt::While { + condition, + body: while_body, + after, + .. + } => { + self.collect_effective_throws_from_expr(*condition, out); + self.collect_effective_throws_from_expr(*while_body, out); + if let Some(after_stmt) = after { + self.collect_effective_throws_from_stmt(*after_stmt, out); + } } - } - Stmt::For { - collection, - body: for_body, - .. - } => { - collect_effective_throws_from_expr( - db, - package_id, - *collection, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_effective_throws_from_expr( - db, - package_id, - *for_body, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - } - Stmt::Return(expr) => { - if let Some(expr) = expr { - collect_effective_throws_from_expr( - db, - package_id, - *expr, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); + Stmt::For { + collection, + body: for_body, + .. + } => { + self.collect_effective_throws_from_expr(*collection, out); + self.collect_effective_throws_from_expr(*for_body, out); } + Stmt::Return(expr) => { + if let Some(expr) = expr { + self.collect_effective_throws_from_expr(*expr, out); + } + } + Stmt::Assign { target, value } | Stmt::AssignOp { target, value, .. } => { + self.collect_effective_throws_from_expr(*target, out); + self.collect_effective_throws_from_expr(*value, out); + } + Stmt::Throw { value } => { + self.collect_effective_throws_from_expr(*value, out); + collect_throw_facts_from_value(*value, self.body, self.expressions, out); + } + Stmt::Break | Stmt::Continue | Stmt::Missing | Stmt::HeaderComment { .. } => {} } - Stmt::Assign { target, value } | Stmt::AssignOp { target, value, .. } => { - collect_effective_throws_from_expr( - db, - package_id, - *target, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_effective_throws_from_expr( - db, - package_id, - *value, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - } - Stmt::Throw { value } => { - collect_effective_throws_from_expr( - db, - package_id, - *value, - body, - expressions, - catch_residual_throws, - aliases, - include_typevars, - unknown_on_unresolved_call, - out, - ); - collect_throw_facts_from_value(*value, body, expressions, out); - } - Stmt::Break | Stmt::Continue | Stmt::Missing | Stmt::HeaderComment { .. } => {} } -} -#[allow(clippy::too_many_arguments)] -fn collect_effective_throws_from_call<'db>( - db: &'db dyn crate::Db, - package_id: PackageId<'db>, - callee_expr_id: ExprId, - body: &ExprBody, - expressions: &FxHashMap, - aliases: &HashMap, - options: CallResolutionOptions, - out: &mut BTreeSet, -) { - let type_level_facts = expressions - .get(&callee_expr_id) - .and_then(|ty| function_throws_facts(ty, aliases)); + fn collect_effective_throws_from_call(&self, callee_expr_id: ExprId, out: &mut BTreeSet) { + let type_level_facts = self + .expressions + .get(&callee_expr_id) + .and_then(|ty| function_throws_facts(ty, self.aliases)); - if let Some(facts) = type_level_facts.as_ref() { - let filtered: BTreeSet = facts - .iter() - .filter(|fact| options.include_typevars || !matches!(fact, Ty::TypeVar(_, _))) - .cloned() - .collect(); - if !filtered.is_empty() { - out.extend(filtered); - return; - } - if facts.iter().any(|fact| matches!(fact, Ty::TypeVar(_, _))) && !options.include_typevars { - return; + if let Some(facts) = type_level_facts.as_ref() { + let filtered: BTreeSet = facts + .iter() + .filter(|fact| self.include_typevars || !matches!(fact, Ty::TypeVar(_, _))) + .cloned() + .collect(); + if !filtered.is_empty() { + out.extend(filtered); + return; + } + if facts.iter().any(|fact| matches!(fact, Ty::TypeVar(_, _))) && !self.include_typevars + { + return; + } } - } - if let Some(target) = call_target_name(callee_expr_id, body, expressions) { - let throws = function_throw_sets(db, package_id); - if let Some(transitive) = throws.transitive_for(&target) { - out.extend(transitive.iter().cloned()); - return; + if let Some(target) = call_target_name(callee_expr_id, self.body, self.expressions) { + let throws = function_throw_sets(self.db, self.package_id); + if let Some(transitive) = throws.transitive_for(&target) { + out.extend(transitive.iter().cloned()); + return; + } } - } - if options.unknown_on_unresolved_call && type_level_facts.is_none() { - out.insert(Ty::Unknown { - attr: TyAttr::default(), - }); + if self.unknown_on_unresolved_call && type_level_facts.is_none() { + out.insert(Ty::Unknown { + attr: TyAttr::default(), + }); + } } } diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 9defde50fc..42b64acbe2 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -26,7 +26,7 @@ use text_size::TextRange; use crate::{ builder::TypeInferenceBuilder, - callable_boundary::lower_callable_boundary, + callable_boundary::{LoweredCallableBoundary, lower_callable_boundary}, infer_context::{InferContext, TypeCheckDiagnostics}, ty::{Ty, TyAttr}, }; @@ -358,40 +358,37 @@ pub fn infer_scope_types<'db>( }) }); - let boundary = lower_callable_boundary( + let LoweredCallableBoundary { + params, + ret: return_ty, + explicit_throws, + direct_callback_effect_vars: _, + param_diagnostics, + ret_diagnostics, + throws_diagnostics, + } = lower_callable_boundary( db, pkg_items, &pkg_info.namespace_path, &generic_params, sig.as_ref(), self_param_ty.as_ref(), - &mut Vec::new(), + ); + let sig_sm = baml_compiler2_hir::signature::function_signature_source_map( + db, func_loc, ); - // Get declared return type - let mut diags = Vec::new(); - if let Some(te) = sig.return_type.as_ref() { - let _ = crate::lower_type_expr::lower_type_expr_in_ns( - db, - te, - pkg_items, - &pkg_info.namespace_path, - &generic_params, - &mut diags, - ); + // Report named signature lowering diagnostics at the signature spans. + if !ret_diagnostics.is_empty() { + let span = sig_sm.return_type_span.unwrap_or(func_data.span); + for diag in ret_diagnostics { + builder.report_at_span(diag, span); + } } - let return_ty = boundary.ret.clone(); - - // Report unresolved type diagnostics for return type - if !diags.is_empty() { - let sig_sm = - baml_compiler2_hir::signature::function_signature_source_map( - db, func_loc, - ); - if let Some(ret_span) = sig_sm.return_type_span { - for diag in diags.drain(..) { - builder.report_at_span(diag, ret_span); - } + if !throws_diagnostics.is_empty() { + let span = sig_sm.throws_type_span.unwrap_or(func_data.span); + for diag in throws_diagnostics { + builder.report_at_span(diag, span); } } @@ -399,40 +396,23 @@ pub fn infer_scope_types<'db>( builder.set_return_type(return_ty.clone()); // Add parameter bindings as locals - let sig_sm = baml_compiler2_hir::signature::function_signature_source_map( - db, func_loc, - ); - for (i, ((param_name, param_te), (_, param_ty))) in - sig.params.iter().zip(boundary.params.iter()).enumerate() + for (i, (((param_name, _), (_, param_ty)), param_diags)) in sig + .params + .iter() + .zip(params.iter()) + .zip(param_diagnostics.into_iter()) + .enumerate() { - if !(param_name.as_str() == "self" - && matches!(param_te, baml_compiler2_ast::TypeExpr::Unknown { .. })) - { - let mut param_diags = Vec::new(); - let mut param_effect_vars = Vec::new(); - let _ = crate::lower_type_expr::lower_type_expr_with_fn_context( - db, - param_te, - pkg_items, - &pkg_info.namespace_path, - &generic_params, - &mut param_diags, - &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { - param_name: param_name.clone(), - }, - &mut param_effect_vars, - ); - if !param_diags.is_empty() { - let span = sig_sm - .param_type_spans - .get(i) - .copied() - .flatten() - .or_else(|| sig_sm.param_spans.get(i).copied()) - .unwrap_or_default(); - for diag in param_diags { - builder.report_at_span(diag, span); - } + if !param_diags.is_empty() { + let span = sig_sm + .param_type_spans + .get(i) + .copied() + .flatten() + .or_else(|| sig_sm.param_spans.get(i).copied()) + .unwrap_or(func_data.span); + for diag in param_diags { + builder.report_at_span(diag, span); } } builder.add_local(param_name.clone(), param_ty.clone()); @@ -444,9 +424,9 @@ pub fn infer_scope_types<'db>( } // Validate declared `throws` against effective escaping throws. - builder.check_throws_contract( + builder.check_lowered_throws_contract( expr_body, - sig.throws.as_ref(), + explicit_throws.as_ref(), sig_sm.throws_type_span, func_data.span, ); diff --git a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs index 4d64faa995..697136edce 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs @@ -521,15 +521,7 @@ fn collect_direct_param_call_throws<'db>( body: &ExprBody, aliases: &HashMap, ) -> BTreeSet { - let boundary = lower_callable_boundary( - db, - pkg_items, - ns_context, - generic_params, - sig, - None, - &mut Vec::new(), - ); + let boundary = lower_callable_boundary(db, pkg_items, ns_context, generic_params, sig, None); let directly_invoked = directly_invoked_callback_params(body); if directly_invoked.is_empty() { diff --git a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs index fc0544af8e..3d3fccc642 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs @@ -269,6 +269,9 @@ fn format_postfix_base_for_diagnostic(ty: &Ty) -> String { } pub(crate) fn format_ty_for_diagnostic(ty: &Ty) -> String { + // NOTE: Keep this in sync with `impl Display for Ty`, except for the + // throws-specific user-facing translations (for example `__throws_*` type + // vars and the diagnostic-only throws rendering). match ty { Ty::Class(qn, _) | Ty::Enum(qn, _) | Ty::TypeAlias(qn, _) => qn.to_string(), Ty::EnumVariant(qn, variant, _) => format!("{qn}.{variant}"), diff --git a/baml_language/crates/baml_compiler2_tir/src/ty.rs b/baml_language/crates/baml_compiler2_tir/src/ty.rs index 90c5777959..c716f7b3bd 100644 --- a/baml_language/crates/baml_compiler2_tir/src/ty.rs +++ b/baml_language/crates/baml_compiler2_tir/src/ty.rs @@ -454,6 +454,9 @@ impl Ty { impl fmt::Display for Ty { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // NOTE: Keep this in sync with `throws_semantics::format_ty_for_diagnostic`. + // That formatter intentionally layers throws-specific user-facing wording + // on top of the structural rendering here. match self { Ty::Class(qn, _) => write!(f, "{qn}"), Ty::Enum(qn, _) => write!(f, "{qn}"), diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/nested_function_throws_validation.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/nested_function_throws_validation.baml new file mode 100644 index 0000000000..36c40c59ba --- /dev/null +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/nested_function_throws_validation.baml @@ -0,0 +1,28 @@ +function BadNestedBuiltin(cb: () -> int throws $rust_type) -> int { + cb() +} + +function BadNestedUnknownAttr(cb: () -> int throws string @unknown) -> int { + cb() +} + +//---- +//- diagnostics +// Error: Builtin-only syntax `$rust_type` is only allowed in builtin stdlib files +// ╭─[ throw_nested_function_throws_validation.baml:1:30 ] +// │ +// 1 │ function BadNestedBuiltin(cb: () -> int throws $rust_type) -> int { +// │ ──────────────┬───────────── +// │ ╰─────────────── builtin-only syntax +// │ +// │ Note: Error code: E0016 +// ───╯ +// Error: Unknown attribute `@unknown` +// ╭─[ throw_nested_function_throws_validation.baml:5:59 ] +// │ +// 5 │ function BadNestedUnknownAttr(cb: () -> int throws string @unknown) -> int { +// │ ────┬─── +// │ ╰───── unknown attribute +// │ +// │ Note: Error code: E0015 +// ───╯ diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml index ecf2564709..8857bccb2e 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml @@ -68,6 +68,17 @@ function test_ignore_callback_throwing() -> int { ignore_callback(() -> int { throw "unused" }) } +// Aliased callback invocation: the rethrow is NOT propagated because the +// syntactic scan only recognizes direct `f()` calls. +function apply_aliased(f: () -> int) -> int { + let g = f + g() +} + +function test_apply_aliased_throwing() -> int { + apply_aliased(() -> int { throw "aliased" }) +} + // Concrete body throws still escape even when the callback is never invoked. function ignore_callback_but_throw(f: () -> int) -> int { helper_with_body_throw() diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap index 19c378b35f..492b8efd86 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 11105 --- Slash "/" Slash "/" @@ -389,6 +390,75 @@ RParen ")" RBrace "}" Slash "/" Slash "/" +Word "Aliased" +Word "callback" +Word "invocation" +Colon ":" +Word "the" +Word "rethrow" +Word "is" +Word "NOT" +Word "propagated" +Word "because" +Word "the" +Slash "/" +Slash "/" +Word "syntactic" +Word "scan" +Word "only" +Word "recognizes" +Word "direct" +Error "`" +Word "f" +LParen "(" +RParen ")" +Error "`" +Word "calls" +Dot "." +Function "function" +Word "apply_aliased" +LParen "(" +Word "f" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "g" +Equals "=" +Word "f" +Word "g" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_apply_aliased_throwing" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "apply_aliased" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Throw "throw" +Quote "\"" +Word "aliased" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" Word "Concrete" Word "body" Throws "throws" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap index d975f965ff..b232607831 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__hof_throws.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 11796 --- === SYNTAX TREE === SOURCE_FILE @@ -446,6 +447,73 @@ SOURCE_FILE R_BRACE "}" R_PAREN ")" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "apply_aliased" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "f" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let g = f" + KW_LET "let" + WORD "g" + EQUALS "=" + WORD "f" + CALL_EXPR + WORD "g" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_apply_aliased_throwing" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "apply_aliased" + CALL_ARGS + L_PAREN "(" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "aliased" + QUOTE """ + WORD "aliased" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" WORD "ignore_callback_but_throw" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index cf15232a4d..ff052a5853 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 12292 --- === HIR2 === function user.test_array_map_pure() -> int[] [expr] { @@ -269,6 +270,9 @@ function user.test_two_pure() -> int [expr] { function user.apply(f: () -> int) -> int [expr] { { } f() } +function user.apply_aliased(f: () -> int) -> int [expr] { + { let g = f } g() +} function user.apply_and_throw(f: () -> int) -> int [expr] { { let result = f(); if (result Lt 0) { throw "negative result" } } result } @@ -290,6 +294,9 @@ function user.ignore_callback(f: () -> int) -> int [expr] { function user.ignore_callback_but_throw(f: () -> int) -> int [expr] { { } helper_with_body_throw() } +function user.test_apply_aliased_throwing() -> int [expr] { + { } apply_aliased(() -> int { { throw "aliased" } }) +} function user.test_apply_and_throw_pure() -> int [expr] { { } apply_and_throw(() -> int { { } 42 }) } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index c05f0a9734..3db31776d4 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 12668 --- === MIR2 === fn user.test_array_map_pure() -> int[] { @@ -2297,6 +2298,28 @@ fn user.apply(f: () -> int) -> int { } } +fn user.apply_aliased(f: () -> int) -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> int // f // param + let _2: () -> int // g + let _3: () -> int + + bb0: { + _2 = copy _1; + _3 = copy _2; + _0 = call copy _3() -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + fn user.apply_and_throw(f: () -> int) -> int { // Locals: let _0: int // _0 // return @@ -2434,6 +2457,35 @@ fn user.ignore_callback_but_throw(f: () -> int) -> int { } } +fn user.test_apply_aliased_throwing() -> int { + // Locals: + let _0: int // _0 // return + let _1: () -> void throws string + + bb0: { + _1 = make_closure lambda[0](); + _0 = call const fn user.apply_aliased(copy _1) -> [bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "aliased"; + } +} + fn user.test_apply_and_throw_pure() -> int { // Locals: let _0: int // _0 // return diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 2220b16135..b165fcc16c 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 12476 --- === TIR2 === function user.test_array_map_pure() -> int[] throws never { @@ -901,6 +902,23 @@ function user.test_ignore_callback_throwing() -> int throws never { } lambda user.test_ignore_callback_throwing { } +function user.apply_aliased(f: () -> int throws __throws_f) -> int throws __throws_f { + { : int + let g = f : () -> int throws __throws_f + g() : int + } +} +function user.test_apply_aliased_throwing() -> int throws never { + { : int + apply_aliased(() -> int { ... }) : int + () -> int { ... } : () -> never throws string + { + throw "aliased" + } + } +} +lambda user.test_apply_aliased_throwing { +} function user.ignore_callback_but_throw(f: () -> int throws __throws_f) -> int throws __throws_f | string { { : int helper_with_body_throw() : int diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 1f96421dfa..cafd3c34d1 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 13077 --- function user.MethodRunner.apply(self: null, f: (void) -> int) -> int { load_var self @@ -28,6 +29,12 @@ function user.apply(f: () -> int) -> int { return } +function user.apply_aliased(f: () -> int) -> int { + load_var f + call_indirect + return +} + function user.apply_and_throw(f: () -> int) -> int { load_var f call_indirect @@ -483,6 +490,12 @@ function user.takes_throwing(f: () -> int throws string) -> int { return } +function user.test_apply_aliased_throwing() -> int { + make_closure ., 0 + call user.apply_aliased + return +} + function user.test_apply_and_throw_pure() -> int { make_closure ., 0 call user.apply_and_throw From 73cdd0f1bcc985da2f5419cdcf703fa0d041fd93 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 01:03:27 -0500 Subject: [PATCH 14/26] Refresh function_type_throws MIR and codegen snapshots --- ...tests__function_type_throws__04_5_mir.snap | 761 +++--------------- ...sts__function_type_throws__06_codegen.snap | 68 +- 2 files changed, 139 insertions(+), 690 deletions(-) diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index 3db31776d4..f8da30a8b9 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 12668 --- === MIR2 === fn user.test_array_map_pure() -> int[] { @@ -18,10 +17,6 @@ fn user.test_array_map_pure() -> int[] { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -57,10 +52,6 @@ fn user.test_array_map_throwing() -> int[] { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -74,23 +65,19 @@ fn .(x: null) -> null { bb0: { _2 = copy _1 == const 2_i64; - branch copy _2 -> [bb4, bb1]; + branch copy _2 -> [bb3, bb1]; } bb1: { + _0 = copy _1 * const 2_i64; goto -> bb2; } bb2: { - _0 = copy _1 * const 2_i64; - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { throw const "found two"; } } @@ -105,10 +92,6 @@ fn user.takes_throwing(f: () -> int throws string) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -121,23 +104,19 @@ fn user.caught_may_fail(x: int) -> string { bb0: { _2 = copy _1 == const 0_i64; - branch copy _2 -> [bb4, bb1]; + branch copy _2 -> [bb3, bb1]; } bb1: { + _0 = const "ok"; goto -> bb2; } bb2: { - _0 = const "ok"; - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { throw const "zero"; } } @@ -148,31 +127,19 @@ fn user.test_catch_all() -> string { let _1: unknown // e bb0: { - _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb1, unwind: bb2]; + _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb3, unwind: bb1]; } bb1: { - goto -> bb5; + throw_if_panic copy _1 -> bb2; } bb2: { - throw_if_panic copy _1 -> bb3; - } - - bb3: { - goto -> bb4; - } - - bb4: { _0 = const "caught"; - goto -> bb5; - } - - bb5: { - goto -> bb6; + goto -> bb3; } - bb6: { + bb3: { return; } } @@ -183,30 +150,18 @@ fn user.test_catch_rethrow() -> string { let _1: unknown // e bb0: { - _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb1, unwind: bb4]; + _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb1, unwind: bb2]; } bb1: { - goto -> bb2; + return; } bb2: { - goto -> bb3; + throw_if_panic copy _1 -> bb3; } bb3: { - return; - } - - bb4: { - throw_if_panic copy _1 -> bb5; - } - - bb5: { - goto -> bb6; - } - - bb6: { throw const 42_i64; } } @@ -218,32 +173,24 @@ fn user.test_catch_string() -> string { let _2: bool bb0: { - _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb1, unwind: bb2]; + _0 = call const fn user.caught_may_fail(const 1_i64) -> [bb4, unwind: bb1]; } bb1: { - goto -> bb5; - } - - bb2: { _2 = is_type(copy _1, String { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); - branch copy _2 -> [bb4, bb3]; + branch copy _2 -> [bb3, bb2]; } - bb3: { + bb2: { throw copy _1; } - bb4: { + bb3: { _0 = const "caught string"; - goto -> bb5; - } - - bb5: { - goto -> bb6; + goto -> bb4; } - bb6: { + bb4: { return; } } @@ -257,10 +204,6 @@ fn user.test_no_catch() -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -275,10 +218,6 @@ fn user.apply_inner(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -293,10 +232,6 @@ fn user.apply_outer(g: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -312,10 +247,6 @@ fn user.test_chained_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -346,10 +277,6 @@ fn user.test_chained_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -377,10 +304,6 @@ fn user.MethodRunner.apply(self: MethodRunner, f: (void) -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -398,10 +321,6 @@ fn user.MethodRunner.apply_underdeclared(self: MethodRunner, f: (void) -> int) - } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -478,10 +397,6 @@ fn user.test_method_runner() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -538,10 +453,6 @@ fn .(a: void) -> null { } bb2: { - goto -> bb3; - } - - bb3: { return; } } @@ -559,10 +470,6 @@ fn user.test_compose_both_throw() -> (int) -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -602,10 +509,6 @@ fn user.test_compose_first_throws() -> (int) -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -650,10 +553,6 @@ fn user.test_compose_pure() -> (int) -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -703,10 +602,6 @@ fn user.test_compose_second_throws() -> (int) -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -748,10 +643,6 @@ fn user.apply_generic_class(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -766,10 +657,6 @@ fn user.apply_generic_enum(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -784,10 +671,6 @@ fn user.apply_may_throw_class(f: () -> int throws ApiError) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -802,10 +685,6 @@ fn user.apply_may_throw_enum(f: () -> int throws ErrorKind) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -820,34 +699,26 @@ fn user.catch_class_error(x: int) -> string { let _5: ApiError bb0: { - _0 = call const fn user.throw_class_instance(copy _1) -> [bb1, unwind: bb2]; + _0 = call const fn user.throw_class_instance(copy _1) -> [bb4, unwind: bb1]; } bb1: { - goto -> bb5; - } - - bb2: { _3 = is_type(copy _2, Class(TypeName { name: "ApiError", module_path: ["user"], display_name: "ApiError" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] })); - branch copy _3 -> [bb4, bb3]; + branch copy _3 -> [bb3, bb2]; } - bb3: { + bb2: { throw copy _2; } - bb4: { + bb3: { _5 = copy _2; _4 = copy _5.1; _0 = const "api error: " + copy _4; - goto -> bb5; - } - - bb5: { - goto -> bb6; + goto -> bb4; } - bb6: { + bb4: { return; } } @@ -857,55 +728,37 @@ fn user.catch_enum_variants(x: int) -> string { let _0: string // _0 // return let _1: int // x // param let _2: unknown // e - let _3: bool - let _4: bool + let _3: int bb0: { - _0 = call const fn user.throw_enum_variant(copy _1) -> [bb1, unwind: bb2]; + _0 = call const fn user.throw_enum_variant(copy _1) -> [bb6, unwind: bb1]; } bb1: { - goto -> bb9; + _3 = discriminant(_2); + switch copy _3 [ErrorKind.NotFound: bb5, ErrorKind.Unauthorized: bb4, otherwise: bb2]; } bb2: { - _3 = copy _2 == const user.ErrorKind.NotFound; - branch copy _3 -> [bb8, bb3]; + throw_if_panic copy _2 -> bb3; } bb3: { - _4 = copy _2 == const user.ErrorKind.Unauthorized; - branch copy _4 -> [bb7, bb4]; + _0 = const "other error"; + goto -> bb6; } bb4: { - throw_if_panic copy _2 -> bb5; + _0 = const "unauthorized"; + goto -> bb6; } bb5: { + _0 = const "not found"; goto -> bb6; } bb6: { - _0 = const "other error"; - goto -> bb9; - } - - bb7: { - _0 = const "unauthorized"; - goto -> bb9; - } - - bb8: { - _0 = const "not found"; - goto -> bb9; - } - - bb9: { - goto -> bb10; - } - - bb10: { return; } } @@ -919,51 +772,39 @@ fn user.catch_mixed_errors(x: int) -> string { let _4: bool bb0: { - _0 = call const fn user.throw_enum_or_class(copy _1) -> [bb1, unwind: bb2]; + _0 = call const fn user.throw_enum_or_class(copy _1) -> [bb7, unwind: bb1]; } bb1: { - goto -> bb9; + _3 = copy _2 == const user.ErrorKind.RateLimited; + branch copy _3 -> [bb6, bb2]; } bb2: { - _3 = copy _2 == const user.ErrorKind.RateLimited; - branch copy _3 -> [bb8, bb3]; + _4 = is_type(copy _2, Class(TypeName { name: "ApiError", module_path: ["user"], display_name: "ApiError" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] })); + branch copy _4 -> [bb5, bb3]; } bb3: { - _4 = is_type(copy _2, Class(TypeName { name: "ApiError", module_path: ["user"], display_name: "ApiError" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] })); - branch copy _4 -> [bb7, bb4]; + throw_if_panic copy _2 -> bb4; } bb4: { - throw_if_panic copy _2 -> bb5; - } - - bb5: { - goto -> bb6; - } - - bb6: { _0 = const "unknown"; - goto -> bb9; + goto -> bb7; } - bb7: { + bb5: { _0 = const "api error"; - goto -> bb9; + goto -> bb7; } - bb8: { + bb6: { _0 = const "rate limited"; - goto -> bb9; - } - - bb9: { - goto -> bb10; + goto -> bb7; } - bb10: { + bb7: { return; } } @@ -1033,10 +874,6 @@ fn user.test_apply_class_thrower() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1064,10 +901,6 @@ fn user.test_apply_enum_thrower() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1095,10 +928,6 @@ fn user.test_lambda_throws_class() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1128,10 +957,6 @@ fn user.test_lambda_throws_enum() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1157,10 +982,6 @@ fn user.test_rethrows_class() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1188,10 +1009,6 @@ fn user.test_rethrows_enum() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1217,10 +1034,6 @@ fn user.test_type_alias_class() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1248,10 +1061,6 @@ fn user.test_type_alias_enum() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1277,10 +1086,6 @@ fn user.test_type_alias_mixed() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1302,7 +1107,7 @@ fn user.throw_any_error(x: int) -> string { let _2: ApiError bb0: { - switch copy _1 [0: bb6, 1: bb5, 2: bb4, otherwise: bb1]; + switch copy _1 [0: bb5, 1: bb4, 2: bb3, otherwise: bb1]; } bb1: { @@ -1311,23 +1116,19 @@ fn user.throw_any_error(x: int) -> string { } bb2: { - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { _2 = ApiError { const 400_i64, const "bad request" }; throw copy _2; } - bb5: { + bb4: { throw const user.ErrorKind.NotFound; } - bb6: { + bb5: { throw const "simple string error"; } } @@ -1341,23 +1142,19 @@ fn user.throw_class_instance(x: int) -> string { bb0: { _2 = copy _1 < const 0_i64; - branch copy _2 -> [bb4, bb1]; + branch copy _2 -> [bb3, bb1]; } bb1: { + _0 = const "ok"; goto -> bb2; } bb2: { - _0 = const "ok"; - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { _3 = ApiError { const 500_i64, const "internal error" }; throw copy _3; } @@ -1373,37 +1170,29 @@ fn user.throw_enum_or_class(x: int) -> string { bb0: { _2 = copy _1 == const 0_i64; - branch copy _2 -> [bb7, bb1]; + branch copy _2 -> [bb5, bb1]; } bb1: { - goto -> bb2; - } - - bb2: { _3 = copy _1 == const 1_i64; - branch copy _3 -> [bb6, bb3]; + branch copy _3 -> [bb4, bb2]; } - bb3: { - goto -> bb4; - } - - bb4: { + bb2: { _0 = const "ok"; - goto -> bb5; + goto -> bb3; } - bb5: { + bb3: { return; } - bb6: { + bb4: { _4 = ApiError { const 503_i64, const "unavailable" }; throw copy _4; } - bb7: { + bb5: { throw const user.ErrorKind.RateLimited; } } @@ -1417,36 +1206,28 @@ fn user.throw_enum_variant(x: int) -> string { bb0: { _2 = copy _1 == const 0_i64; - branch copy _2 -> [bb7, bb1]; + branch copy _2 -> [bb5, bb1]; } bb1: { - goto -> bb2; - } - - bb2: { _3 = copy _1 == const 1_i64; - branch copy _3 -> [bb6, bb3]; - } - - bb3: { - goto -> bb4; + branch copy _3 -> [bb4, bb2]; } - bb4: { + bb2: { _0 = const "ok"; - goto -> bb5; + goto -> bb3; } - bb5: { + bb3: { return; } - bb6: { + bb4: { throw const user.ErrorKind.Unauthorized; } - bb7: { + bb5: { throw const user.ErrorKind.NotFound; } } @@ -1462,37 +1243,29 @@ fn user.throw_mixed_classes(x: int) -> string { bb0: { _2 = copy _1 == const 0_i64; - branch copy _2 -> [bb7, bb1]; + branch copy _2 -> [bb5, bb1]; } bb1: { - goto -> bb2; - } - - bb2: { _4 = copy _1 == const 1_i64; - branch copy _4 -> [bb6, bb3]; + branch copy _4 -> [bb4, bb2]; } - bb3: { - goto -> bb4; - } - - bb4: { + bb2: { _0 = const "ok"; - goto -> bb5; + goto -> bb3; } - bb5: { + bb3: { return; } - bb6: { + bb4: { _5 = ValidationError { const "id", const "required" }; throw copy _5; } - bb7: { + bb5: { _3 = ApiError { const 404_i64, const "not found" }; throw copy _3; } @@ -1504,7 +1277,7 @@ fn user.throw_various_errors(x: int) -> string { let _1: int // x // param bb0: { - switch copy _1 [0: bb6, 1: bb5, 2: bb4, otherwise: bb1]; + switch copy _1 [0: bb5, 1: bb4, 2: bb3, otherwise: bb1]; } bb1: { @@ -1513,22 +1286,18 @@ fn user.throw_various_errors(x: int) -> string { } bb2: { - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { throw const user.ErrorKind.ValidationFailed; } - bb5: { + bb4: { throw const user.ErrorKind.Unauthorized; } - bb6: { + bb5: { throw const user.ErrorKind.NotFound; } } @@ -1543,10 +1312,6 @@ fn user.use_class_thrower(f: () -> int throws ApiError) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1561,10 +1326,6 @@ fn user.use_enum_thrower(f: () -> int throws ErrorKind) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1579,10 +1340,6 @@ fn user.use_mixed_thrower(f: () -> int throws ErrorKind | ApiError | string) -> } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1597,10 +1354,6 @@ fn user.apply_explicit(f: () -> int throws string) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1615,10 +1368,6 @@ fn user.apply_pure_only(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1634,10 +1383,6 @@ fn user.test_explicit_param_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1668,10 +1413,6 @@ fn user.test_explicit_param_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1697,10 +1438,6 @@ fn user.test_pure_only_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1744,10 +1481,6 @@ fn user.caller() -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1760,23 +1493,19 @@ fn user.declared_may_fail(x: int) -> string { bb0: { _2 = copy _1 == const 0_i64; - branch copy _2 -> [bb4, bb1]; + branch copy _2 -> [bb3, bb1]; } bb1: { + _0 = const "ok"; goto -> bb2; } bb2: { - _0 = const "ok"; - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { throw const "zero"; } } @@ -1787,31 +1516,19 @@ fn user.safe_caller() -> string { let _1: unknown // e bb0: { - _0 = call const fn user.declared_may_fail(const 1_i64) -> [bb1, unwind: bb2]; + _0 = call const fn user.declared_may_fail(const 1_i64) -> [bb3, unwind: bb1]; } bb1: { - goto -> bb5; + throw_if_panic copy _1 -> bb2; } bb2: { - throw_if_panic copy _1 -> bb3; - } - - bb3: { - goto -> bb4; - } - - bb4: { _0 = const "caught"; - goto -> bb5; - } - - bb5: { - goto -> bb6; + goto -> bb3; } - bb6: { + bb3: { return; } } @@ -1826,10 +1543,6 @@ fn user.use_alias_pure(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1849,23 +1562,19 @@ fn user.apply_guarded(f: () -> int) -> int { bb1: { _4 = copy _2; _3 = copy _4 < const 0_i64; - branch copy _3 -> [bb5, bb2]; + branch copy _3 -> [bb4, bb2]; } bb2: { + _0 = copy _2; goto -> bb3; } bb3: { - _0 = copy _2; - goto -> bb4; - } - - bb4: { return; } - bb5: { + bb4: { throw const "negative result"; } } @@ -1881,10 +1590,6 @@ fn user.test_guarded_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1915,10 +1620,6 @@ fn user.test_guarded_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1944,10 +1645,6 @@ fn user.map_it(x: int, f: (int) -> string) -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -1961,11 +1658,7 @@ fn user.run_pure(f: () -> int) -> int { _0 = call copy _1() -> [bb1]; } - bb1: { - goto -> bb2; - } - - bb2: { + bb1: { return; } } @@ -1980,10 +1673,6 @@ fn user.run_throwing(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2025,10 +1714,6 @@ fn user.test_map_it_pure() -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2060,10 +1745,6 @@ fn user.test_map_it_throwing() -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2090,10 +1771,6 @@ fn user.test_run_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2124,10 +1801,6 @@ fn user.test_run_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2155,10 +1828,6 @@ fn user.test_two_both_throw() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2196,10 +1865,6 @@ fn user.test_two_one_throws() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2242,10 +1907,6 @@ fn user.test_two_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2290,10 +1951,6 @@ fn user.apply(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2312,10 +1969,6 @@ fn user.apply_aliased(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2335,23 +1988,19 @@ fn user.apply_and_throw(f: () -> int) -> int { bb1: { _4 = copy _2; _3 = copy _4 < const 0_i64; - branch copy _3 -> [bb5, bb2]; + branch copy _3 -> [bb4, bb2]; } bb2: { + _0 = copy _2; goto -> bb3; } bb3: { - _0 = copy _2; - goto -> bb4; - } - - bb4: { return; } - bb5: { + bb4: { throw const "negative result"; } } @@ -2366,10 +2015,6 @@ fn user.apply_throwing(f: () -> int throws string) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2384,10 +2029,6 @@ fn user.apply_underdeclared(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2407,10 +2048,6 @@ fn user.apply_with_helper(f: () -> int) -> int { } bb2: { - goto -> bb3; - } - - bb3: { return; } } @@ -2449,10 +2086,6 @@ fn user.ignore_callback_but_throw(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2468,10 +2101,6 @@ fn user.test_apply_aliased_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2497,10 +2126,6 @@ fn user.test_apply_and_throw_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2531,10 +2156,6 @@ fn user.test_apply_explicit_throws() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2560,10 +2181,6 @@ fn user.test_apply_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2594,10 +2211,6 @@ fn user.test_apply_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2623,10 +2236,6 @@ fn user.test_apply_with_helper_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2657,10 +2266,6 @@ fn user.test_apply_with_helper_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2686,10 +2291,6 @@ fn user.test_ignore_callback_but_throw() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2715,10 +2316,6 @@ fn user.test_ignore_callback_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2746,10 +2343,6 @@ fn user.test_explicit_throws_match() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2777,10 +2370,6 @@ fn user.test_explicit_throws_never_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2813,10 +2402,6 @@ fn user.test_explicit_throws_wider_than_body() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2850,10 +2435,6 @@ fn user.test_conditional_throw(x: int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2867,23 +2448,19 @@ fn .(n: int) -> null { bb0: { _2 = copy _1 < const 0_i64; - branch copy _2 -> [bb4, bb1]; + branch copy _2 -> [bb3, bb1]; } bb1: { + _0 = copy _1; goto -> bb2; } bb2: { - _0 = copy _1; - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { throw const "negative"; } } @@ -2902,10 +2479,6 @@ fn user.test_multi_throw_types(x: int) -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2917,7 +2490,7 @@ fn .(n: int) -> null { let _1: int // n // param bb0: { - switch copy _1 [0: bb5, 1: bb4, otherwise: bb1]; + switch copy _1 [0: bb4, 1: bb3, otherwise: bb1]; } bb1: { @@ -2926,18 +2499,14 @@ fn .(n: int) -> null { } bb2: { - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { throw const 1_i64; } - bb5: { + bb4: { throw const "string error"; } } @@ -2955,10 +2524,6 @@ fn user.test_pure_lambda() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -2991,10 +2556,6 @@ fn user.test_throwing_int() -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3022,10 +2583,6 @@ fn user.test_throwing_lambda() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3053,10 +2610,6 @@ fn user.test_throws_never_but_throws() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3084,10 +2637,6 @@ fn user.test_typed_lambda_annotation_underdeclares_body() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3115,10 +2664,6 @@ fn user.test_typed_lambda_throws_mismatch() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3144,10 +2689,6 @@ fn user.apply_with_arg(x: int, f: (int) -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3164,10 +2705,6 @@ fn user.apply_with_many_args(a: int, b: string, f: (int, string) -> string) -> s } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3183,10 +2720,6 @@ fn user.test_many_args_pure() -> string { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3219,10 +2752,6 @@ fn user.test_mixed_pure() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3254,10 +2783,6 @@ fn user.test_mixed_throwing() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3271,23 +2796,19 @@ fn .(n: int) -> null { bb0: { _2 = copy _1 < const 0_i64; - branch copy _2 -> [bb4, bb1]; + branch copy _2 -> [bb3, bb1]; } bb1: { + _0 = copy _1 * const 2_i64; goto -> bb2; } bb2: { - _0 = copy _1 * const 2_i64; - goto -> bb3; - } - - bb3: { return; } - bb4: { + bb3: { throw const "negative"; } } @@ -3305,10 +2826,6 @@ fn user.test_nested_both_throw() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3355,10 +2872,6 @@ fn user.test_nested_inner_throws() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3377,10 +2890,6 @@ fn .() -> null { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3408,10 +2917,6 @@ fn user.test_nested_outer_throws() -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3450,44 +2955,28 @@ fn user.optional_call_caught(cb: (() -> int throws string)?) -> int? { bb0: { _3 = copy _1 == const null; - branch copy _3 -> [bb3, bb1]; + branch copy _3 -> [bb2, bb1]; } bb1: { - _0 = call copy _1() -> [bb2, unwind: bb5]; + _0 = call copy _1() -> [bb5, unwind: bb3]; } bb2: { - goto -> bb4; + _0 = const null; + goto -> bb5; } bb3: { - _0 = const null; - goto -> bb4; + throw_if_panic copy _2 -> bb4; } bb4: { - goto -> bb8; - } - - bb5: { - throw_if_panic copy _2 -> bb6; - } - - bb6: { - goto -> bb7; - } - - bb7: { _0 = const null; - goto -> bb8; - } - - bb8: { - goto -> bb9; + goto -> bb5; } - bb9: { + bb5: { return; } } @@ -3500,27 +2989,19 @@ fn user.optional_call_rethrows(cb: (() -> int throws string)?) -> int? { bb0: { _2 = copy _1 == const null; - branch copy _2 -> [bb3, bb1]; + branch copy _2 -> [bb2, bb1]; } bb1: { - _0 = call copy _1() -> [bb2]; + _0 = call copy _1() -> [bb3]; } bb2: { - goto -> bb4; - } - - bb3: { _0 = const null; - goto -> bb4; - } - - bb4: { - goto -> bb5; + goto -> bb3; } - bb5: { + bb3: { return; } } @@ -3536,10 +3017,6 @@ fn user.test_optional_call_with_throwing_callback() -> int? { } bb1: { - goto -> bb2; - } - - bb2: { return; } } @@ -3623,10 +3100,6 @@ fn user.test_use_pure() -> int { } bb2: { - goto -> bb3; - } - - bb3: { return; } } @@ -3647,10 +3120,6 @@ fn user.test_use_thrower() -> int { } bb2: { - goto -> bb3; - } - - bb3: { return; } } @@ -3829,10 +3298,6 @@ fn user.test_store_direct_callback_param(f: () -> int) -> int { } bb1: { - goto -> bb2; - } - - bb2: { return; } } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index cafd3c34d1..505bc2bcfc 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 13077 --- function user.MethodRunner.apply(self: null, f: (void) -> int) -> int { load_var self @@ -191,21 +190,24 @@ function user.catch_enum_variants(x: int) -> string { call user.throw_enum_variant jump L4 load_var e - load_const user.ErrorKind.NotFound - alloc_variant user.ErrorKind + discriminant + copy 0 + load_const ErrorKind.NotFound cmp_op == pop_jump_if_false L0 + pop 1 jump L3 L0: - load_var e - load_const user.ErrorKind.Unauthorized - alloc_variant user.ErrorKind + copy 0 + load_const ErrorKind.Unauthorized cmp_op == pop_jump_if_false L1 + pop 1 jump L2 L1: + pop 1 load_var e throw_if_panic load_const "other error" @@ -319,33 +321,29 @@ function user.ignore_callback_but_throw(f: () -> int) -> int { function user.make_bad_stored_handler() -> StoredPureHandler { alloc_instance StoredPureHandler - copy 0 make_closure ., 0 - store_field .run + init_field .run return } function user.make_class_handler() -> ClassThrowingHandler { alloc_instance ClassThrowingHandler - copy 0 make_closure ., 0 - store_field .run + init_field .run return } function user.make_enum_handler() -> EnumThrowingHandler { alloc_instance EnumThrowingHandler - copy 0 make_closure ., 0 - store_field .run + init_field .run return } function user.make_good_stored_handler() -> StoredPureHandler { alloc_instance StoredPureHandler - copy 0 make_closure ., 0 - store_field .run + init_field .run return } @@ -356,9 +354,8 @@ function user.make_pure() -> () -> int { function user.make_pure_handler() -> PureHandler { alloc_instance PureHandler - copy 0 make_closure ., 0 - store_field .run + init_field .run return } @@ -379,9 +376,8 @@ function user.make_stored_closure_with_throws() -> () -> int throws string { function user.make_stored_throwing_handler() -> StoredThrowingHandler { alloc_instance StoredThrowingHandler - copy 0 make_closure ., 0 - store_field .run + init_field .run return } @@ -392,9 +388,8 @@ function user.make_thrower() -> () -> int throws string { function user.make_throwing_handler() -> ThrowingHandler { alloc_instance ThrowingHandler - copy 0 make_closure ., 0 - store_field .run + init_field .run return } @@ -748,9 +743,8 @@ function user.test_map_it_throwing() -> string { function user.test_method_runner() -> int { alloc_instance MethodRunner - copy 0 load_const 1 - store_field .value + init_field .value make_closure ., 0 call user.MethodRunner.apply return @@ -992,12 +986,10 @@ function user.throw_any_error(x: int) -> string { L3: alloc_instance ApiError - copy 0 load_const 400 - store_field .code - copy 0 + init_field .code load_const "bad request" - store_field .message + init_field .message throw L4: @@ -1023,12 +1015,10 @@ function user.throw_class_instance(x: int) -> string { L1: alloc_instance ApiError - copy 0 load_const 500 - store_field .code - copy 0 + init_field .code load_const "internal error" - store_field .message + init_field .message throw } @@ -1052,12 +1042,10 @@ function user.throw_enum_or_class(x: int) -> string { L2: alloc_instance ApiError - copy 0 load_const 503 - store_field .code - copy 0 + init_field .code load_const "unavailable" - store_field .message + init_field .message throw L3: @@ -1115,22 +1103,18 @@ function user.throw_mixed_classes(x: int) -> string { L2: alloc_instance ValidationError - copy 0 load_const "id" - store_field .field - copy 0 + init_field .field load_const "required" - store_field .reason + init_field .reason throw L3: alloc_instance ApiError - copy 0 load_const 404 - store_field .code - copy 0 + init_field .code load_const "not found" - store_field .message + init_field .message throw } From b640a40dcf6061b181db2f7a58797c3b5874fdc8 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 03:39:56 -0500 Subject: [PATCH 15/26] Add UnusedCallbackEffectVar diagnostic for unused callback effect vars Enforce the A+B contract: callback parameters with implicit effect polymorphism that are never directly invoked now produce an actionable error. This closes a type soundness hole where throws were silently dropped. - Add UnusedCallbackEffectVar error variant in infer_context.rs - Emit diagnostic in inference.rs by capturing direct_callback_effect_vars - Add LSP error code mapping in check.rs - Update test files with ERROR comments for expected diagnostics - Fix stdlib functions with explicit throws annotations --- .../baml_std/baml/containers.baml | 4 +- .../baml_std/testing/registry.baml | 6 +- .../baml_compiler2_tir/src/infer_context.rs | 15 + .../baml_compiler2_tir/src/inference.rs | 37 +- .../crates/baml_lsp2_actions/src/check.rs | 2 + .../format_checks/function_decls.baml | 2 +- .../function_type_throws/chained_hof.baml | 1 + .../function_type_throws/compose_hof.baml | 6 +- .../function_type_throws/hof_throws.baml | 3 + .../stored_callback_enforcement.baml | 1 + .../baml_tests____baml_std____04_tir.snap | 4 +- ...baml_tests____testing_std____04_5_mir.snap | 12 +- .../baml_tests____testing_std____04_tir.snap | 8 +- ...ml_tests____testing_std____06_codegen.snap | 6 +- ...rmat_checks__01_lexer__function_decls.snap | 2 + ...mat_checks__02_parser__function_decls.snap | 4 + .../baml_tests__format_checks__04_tir.snap | 25 +- ..._tests__format_checks__05_diagnostics.snap | 10 + ..._checks__10_formatter__function_decls.snap | 401 +----------------- ...on_type_throws__01_lexer__chained_hof.snap | 18 + ...on_type_throws__01_lexer__compose_hof.snap | 71 ++++ ...ion_type_throws__01_lexer__hof_throws.snap | 55 ++- ...01_lexer__stored_callback_enforcement.snap | 18 + ...l_tests__function_type_throws__04_tir.snap | 10 +- ..._function_type_throws__05_diagnostics.snap | 74 +++- ...ype_throws__10_formatter__chained_hof.snap | 1 + ...ype_throws__10_formatter__compose_hof.snap | 2 +- .../baml_tests__lambda_advanced__04_tir.snap | 6 +- ...ode_format__bytecode_display_expanded.snap | 6 +- ...bytecode_display_expanded_unoptimized.snap | 6 +- ...code_format__bytecode_display_textual.snap | 6 +- 31 files changed, 369 insertions(+), 453 deletions(-) diff --git a/baml_language/crates/baml_builtins2/baml_std/baml/containers.baml b/baml_language/crates/baml_builtins2/baml_std/baml/containers.baml index 5566477c16..efd0caa39b 100644 --- a/baml_language/crates/baml_builtins2/baml_std/baml/containers.baml +++ b/baml_language/crates/baml_builtins2/baml_std/baml/containers.baml @@ -72,11 +72,11 @@ class Map { $rust_function } - function map_keys(self, f: (K) -> U) -> U[] { + function map_keys(self, f: (K) -> U throws never) -> U[] { self.keys().map(f) } - function map_values(self, f: (V) -> U) -> U[] { + function map_values(self, f: (V) -> U throws never) -> U[] { self.values().map(f) } } diff --git a/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml b/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml index c3f055e3f4..44ba0e6297 100644 --- a/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml +++ b/baml_language/crates/baml_builtins2/baml_std/testing/registry.baml @@ -23,7 +23,7 @@ class TestCollector { TestCollector { prefix: prefix, tests: [], testsets: [] } } - function register_test(self, name: string, body: () -> null, runner: TestRunner?) -> null { + function register_test(self, name: string, body: () -> null throws unknown, runner: TestRunner?) -> null { let full_name = if (self.prefix == "") { name } else { self.prefix + "/" + name } let count = 0 let hash_prefix = full_name + "#" @@ -35,7 +35,7 @@ class TestCollector { null } - function register_test_set(self, name: string, collector: (TestCollector) -> null, runner: TestSetRunner?) -> null { + function register_test_set(self, name: string, collector: (TestCollector) -> null throws unknown, runner: TestSetRunner?) -> null { let full_name = if (self.prefix == "") { name } else { self.prefix + "/" + name } let count = 0 let hash_prefix = full_name + "#" @@ -171,7 +171,7 @@ class TestRegistry { } // Run a single test with error handling and optional runner middleware. -function run_test(body: () -> null, runner: TestRunner?) -> TestReport { +function run_test(body: () -> null throws unknown, runner: TestRunner?) -> TestReport { let base_run = () -> TestReport { let start = baml.sys.now_ms() let result = { diff --git a/baml_language/crates/baml_compiler2_tir/src/infer_context.rs b/baml_language/crates/baml_compiler2_tir/src/infer_context.rs index b63ab64cb0..383104224e 100644 --- a/baml_language/crates/baml_compiler2_tir/src/infer_context.rs +++ b/baml_language/crates/baml_compiler2_tir/src/infer_context.rs @@ -150,6 +150,13 @@ pub enum TirTypeError { /// The inferred throws type of the actual function being stored. actual_throws: Ty, }, + /// A callback parameter has implicit effect polymorphism (synthetic + /// `__throws_` var) but is never directly invoked in the function + /// body. The effect var is silently dropped, breaking the A+B contract. + UnusedCallbackEffectVar { + /// The name of the callback parameter. + param_name: Name, + }, } impl fmt::Display for TirTypeError { @@ -342,6 +349,14 @@ impl fmt::Display for TirTypeError { add an explicit `throws` annotation to the stored function type" ) } + TirTypeError::UnusedCallbackEffectVar { param_name } => { + write!( + f, + "callback parameter `{param_name}` has implicit effect polymorphism but is not directly invoked; \ + call `{param_name}()` directly to propagate throws, \ + or annotate with explicit `throws never` to opt out" + ) + } } } } diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 42b64acbe2..0d3848e2e5 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -26,7 +26,9 @@ use text_size::TextRange; use crate::{ builder::TypeInferenceBuilder, - callable_boundary::{LoweredCallableBoundary, lower_callable_boundary}, + callable_boundary::{ + LoweredCallableBoundary, directly_invoked_callback_params, lower_callable_boundary, + }, infer_context::{InferContext, TypeCheckDiagnostics}, ty::{Ty, TyAttr}, }; @@ -362,7 +364,7 @@ pub fn infer_scope_types<'db>( params, ret: return_ty, explicit_throws, - direct_callback_effect_vars: _, + direct_callback_effect_vars, param_diagnostics, ret_diagnostics, throws_diagnostics, @@ -418,6 +420,37 @@ pub fn infer_scope_types<'db>( builder.add_local(param_name.clone(), param_ty.clone()); } + // Check for callback params that generate effect vars + // but are never directly invoked — this is a soundness hole. + if !direct_callback_effect_vars.is_empty() { + let invoked = directly_invoked_callback_params(expr_body); + for effect_var in &direct_callback_effect_vars { + if !invoked.contains(&effect_var.param_name) { + // Find the param index to get the right span. + let param_idx = sig + .params + .iter() + .position(|(name, _)| name == &effect_var.param_name); + let span = param_idx + .and_then(|i| { + sig_sm + .param_type_spans + .get(i) + .copied() + .flatten() + .or_else(|| sig_sm.param_spans.get(i).copied()) + }) + .unwrap_or(func_data.span); + builder.report_at_span( + crate::infer_context::TirTypeError::UnusedCallbackEffectVar { + param_name: effect_var.param_name.clone(), + }, + span, + ); + } + } + } + // Check root expression against declared return type if let Some(root_expr) = expr_body.root_expr { builder.check_expr(root_expr, expr_body, &return_ty); diff --git a/baml_language/crates/baml_lsp2_actions/src/check.rs b/baml_language/crates/baml_lsp2_actions/src/check.rs index cc66159a83..426997058e 100644 --- a/baml_language/crates/baml_lsp2_actions/src/check.rs +++ b/baml_language/crates/baml_lsp2_actions/src/check.rs @@ -301,6 +301,8 @@ fn tir_type_error_to_diagnostic_id( TirTypeError::NullableMemberAccess { .. } => DiagnosticId::TypeMismatch, // Stored function throws mismatch TirTypeError::StoredFunctionRequiresExplicitThrows { .. } => DiagnosticId::TypeMismatch, + // Unused callback effect var (implicit effect polymorphism not exercised) + TirTypeError::UnusedCallbackEffectVar { .. } => DiagnosticId::TypeMismatch, } } diff --git a/baml_language/crates/baml_tests/projects/format_checks/function_decls.baml b/baml_language/crates/baml_tests/projects/format_checks/function_decls.baml index 13ed163b75..1979e49d9b 100644 --- a/baml_language/crates/baml_tests/projects/format_checks/function_decls.baml +++ b/baml_language/crates/baml_tests/projects/format_checks/function_decls.baml @@ -74,7 +74,7 @@ function DeepFnReturn() -> (x: int) -> (y: int) -> (z: int) -> (w: int) -> map>>, callback: (x: int) -> (y: int) -> (z: int) -> string) -> string { +function DeepParamTypes(nested: map>>, callback: (x: int) -> (y: int) -> (z: int) -> string throws never) -> string { "done" } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/chained_hof.baml b/baml_language/crates/baml_tests/projects/function_type_throws/chained_hof.baml index d1b501af7d..9cc92075fa 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/chained_hof.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/chained_hof.baml @@ -2,6 +2,7 @@ function apply_inner(f: () -> int) -> int { f() } +// ERROR: callback parameter `g` has implicit effect polymorphism but is not directly invoked function apply_outer(g: () -> int) -> int { apply_inner(g) } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml b/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml index 0968af18a8..f4394a19fe 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/compose_hof.baml @@ -1,6 +1,10 @@ // === Compose HOF: nested higher-order functions with generics === -// Compose two functions, propagating throws from both +// Compose two functions, propagating throws from both. +// The callbacks f and g are captured by the returned lambda rather than being +// directly invoked in the compose body, so they fire the unused effect var diagnostic. +// ERROR: callback parameter `f` has implicit effect polymorphism but is not directly invoked +// ERROR: callback parameter `g` has implicit effect polymorphism but is not directly invoked function compose( f: (A) -> B, g: (B) -> C diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml index 8857bccb2e..190c70efbb 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/hof_throws.baml @@ -60,6 +60,7 @@ function apply_underdeclared(f: () -> int) -> int throws never { } // Unused callback params should not widen the wrapper outward throws. +// ERROR: callback parameter `f` has implicit effect polymorphism but is not directly invoked function ignore_callback(f: () -> int) -> int { 42 } @@ -70,6 +71,7 @@ function test_ignore_callback_throwing() -> int { // Aliased callback invocation: the rethrow is NOT propagated because the // syntactic scan only recognizes direct `f()` calls. +// ERROR: callback parameter `f` has implicit effect polymorphism but is not directly invoked function apply_aliased(f: () -> int) -> int { let g = f g() @@ -80,6 +82,7 @@ function test_apply_aliased_throwing() -> int { } // Concrete body throws still escape even when the callback is never invoked. +// ERROR: callback parameter `f` has implicit effect polymorphism but is not directly invoked function ignore_callback_but_throw(f: () -> int) -> int { helper_with_body_throw() } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml b/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml index f25cff29c3..8c2edc9fed 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/stored_callback_enforcement.baml @@ -90,6 +90,7 @@ function make_stored_closure_with_throws() -> (() -> int throws string) { // ERROR: direct callback params become effect polymorphic at the function boundary, // but storing them still requires an explicit stored function throws annotation. +// Also: callback parameter `f` has implicit effect polymorphism but is not directly invoked function test_store_direct_callback_param(f: () -> int) -> int { let stored: () -> int = f stored() diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap index b66756975d..1943763966 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap @@ -10,12 +10,12 @@ class baml.Array { } class baml.Map { } -function baml.Map.map_keys(self: baml.Map, f: (unknown) -> U throws __throws_f) -> U[] throws __throws_f { +function baml.Map.map_keys(self: baml.Map, f: (unknown) -> U) -> U[] throws never { { : U[] self.keys().map(f) : U[] } } -function baml.Map.map_values(self: baml.Map, f: (unknown) -> U throws __throws_f) -> U[] throws __throws_f { +function baml.Map.map_values(self: baml.Map, f: (unknown) -> U) -> U[] throws never { { : U[] self.values().map(f) : U[] } diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap index 67f0d5c999..f43001d355 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap @@ -273,12 +273,12 @@ fn testing.TestCollector.new(prefix: string) -> testing.TestCollector { } } -fn testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { +fn testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { // Locals: let _0: null // _0 // return let _1: testing.TestCollector // self // param let _2: string // name // param - let _3: () -> null // body // param + let _3: () -> null throws unknown // body // param let _4: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)? // runner // param let _5: string // full_name let _6: bool @@ -421,12 +421,12 @@ fn testing.TestCollector.register_test(self: testing.TestCollector, name: string } } -fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { +fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null throws unknown, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { // Locals: let _0: null // _0 // return let _1: testing.TestCollector // self // param let _2: string // name // param - let _3: (testing.TestCollector) -> null // collector // param + let _3: (testing.TestCollector) -> null throws unknown // collector // param let _4: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)? // runner // param let _5: string // full_name let _6: bool @@ -569,10 +569,10 @@ fn testing.TestCollector.register_test_set(self: testing.TestCollector, name: st } } -fn testing.run_test(body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { +fn testing.run_test(body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { // Locals: let _0: testing.TestReport // _0 // return - let _1: () -> null // body // param [captured] + let _1: () -> null throws unknown // body // param [captured] let _2: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)? // runner // param let _3: () -> testing.TestReport // base_run let _4: bool diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap index cdfa0c4ace..1e640af793 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_tir.snap @@ -26,7 +26,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector thro TestCollector { prefix: prefix, tests: [], testsets: [] } : testing.TestCollector } } -function testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null throws __throws_body, runner: testing.TestRunner?) -> null throws __throws_body { +function testing.TestCollector.register_test(self: testing.TestCollector, name: string, body: () -> null throws unknown, runner: testing.TestRunner?) -> null throws never { { : null let full_name = : string if (self.prefix == "" : bool) : string @@ -63,7 +63,7 @@ function testing.TestCollector.register_test(self: testing.TestCollector, name: null : null } } -function testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null throws __throws_collector, runner: testing.TestSetRunner?) -> null throws __throws_collector { +function testing.TestCollector.register_test_set(self: testing.TestCollector, name: string, collector: (testing.TestCollector) -> null throws unknown, runner: testing.TestSetRunner?) -> null throws never { { : null let full_name = : string if (self.prefix == "" : bool) : string @@ -160,7 +160,7 @@ function testing.TestRegistry.serialize(self: testing.TestRegistry) -> testing.S return items : testing.SerializedTestDef[] } } -function testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> testing.TestReport throws string | unknown | unknown { +function testing.TestRegistry.run_test(self: testing.TestRegistry, name: string) -> testing.TestReport throws string | unknown { { : never for t in self.collector.tests { : void @@ -206,7 +206,7 @@ function testing.TestRegistry.expand_set(self: testing.TestRegistry, name: strin throw "TestSet not found..." + name : string } } -function testing.run_test(body: () -> null throws __throws_body, runner: testing.TestRunner?) -> testing.TestReport throws __throws_body | unknown { +function testing.run_test(body: () -> null throws unknown, runner: testing.TestRunner?) -> testing.TestReport throws unknown { { : testing.TestReport let base_run = : () -> testing.TestReport () -> TestReport { ... } : () -> testing.TestReport diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap index 1e77956649..278c0ceec8 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____06_codegen.snap @@ -64,7 +64,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { load_var self load_field .prefix load_const "" @@ -185,7 +185,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () jump L3 } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null throws unknown, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { load_var self load_field .prefix load_const "" @@ -661,7 +661,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | jump L0 } -function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { +function testing.run_test(body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { load_var ?1 make_cell store_var ?1 diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__01_lexer__function_decls.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__01_lexer__function_decls.snap index b638092931..cac7b032ae 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__01_lexer__function_decls.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__01_lexer__function_decls.snap @@ -610,6 +610,8 @@ Word "int" RParen ")" Arrow "->" Word "string" +Throws "throws" +Word "never" RParen ")" Arrow "->" Word "string" diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__02_parser__function_decls.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__02_parser__function_decls.snap index 2c75c774aa..4d2cdf918b 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__02_parser__function_decls.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__02_parser__function_decls.snap @@ -634,6 +634,10 @@ SOURCE_FILE ARROW "->" TYPE_EXPR "string" WORD "string" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "never" + WORD "never" R_PAREN ")" ARROW "->" TYPE_EXPR "string" diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap index 9dba9be02e..f634debc56 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap @@ -674,6 +674,7 @@ function user.DeepParamTypes(nested: map> { : "done" "done" : "done" } + !! 2487..2543: callback parameter `callback` has implicit effect polymorphism but is not directly invoked; call `callback()` directly to propagate throws, or annotate with explicit `throws never` to opt out } function user.ConstrainedParams(name: string, score: int) -> string throws never { { : "done" @@ -871,30 +872,30 @@ function user.DeepNestLongExpr(a: int, b: int, c: int, d: int) -> int throws nev } function user.LlmBasic(name: string) -> string throws never { baml.llm.call_llm_function(GPT4, "LlmBasic", map { "name": name }) : string - !! 11311..11368: unresolved name: GPT4 + !! 11324..11381: unresolved name: GPT4 } function user.LlmBasic$render_prompt(name: string) -> baml.llm.PromptAst throws never { baml.llm.render_prompt(GPT4, "LlmBasic", map { "name": name }) : baml.llm.PromptAst - !! 11234..11368: unresolved name: GPT4 + !! 11247..11381: unresolved name: GPT4 } function user.LlmBasic$build_request(name: string) -> baml.http.Request throws never { baml.llm.build_request(GPT4, "LlmBasic", map { "name": name }) : baml.http.Request - !! 11234..11368: unresolved name: GPT4 + !! 11247..11381: unresolved name: GPT4 } function user.LlmBasic$parse(json: string) -> string throws never { baml.llm.parse("LlmBasic", json) : string } function user.LlmMultiLine(text: string) -> string throws never { baml.llm.call_llm_function(GPT4, "LlmMultiLine", map { "text": text }) : string - !! 11471..11593: unresolved name: GPT4 + !! 11484..11606: unresolved name: GPT4 } function user.LlmMultiLine$render_prompt(text: string) -> baml.llm.PromptAst throws never { baml.llm.render_prompt(GPT4, "LlmMultiLine", map { "text": text }) : baml.llm.PromptAst - !! 11368..11593: unresolved name: GPT4 + !! 11381..11606: unresolved name: GPT4 } function user.LlmMultiLine$build_request(text: string) -> baml.http.Request throws never { baml.llm.build_request(GPT4, "LlmMultiLine", map { "text": text }) : baml.http.Request - !! 11368..11593: unresolved name: GPT4 + !! 11381..11606: unresolved name: GPT4 } function user.LlmMultiLine$parse(json: string) -> string throws never { baml.llm.parse("LlmMultiLine", json) : string @@ -913,30 +914,30 @@ function user.LlmStringClient$parse(json: string) -> string throws never { } function user.LlmReversedOrder(text: string) -> string throws never { baml.llm.call_llm_function(GPT4, "LlmReversedOrder", map { "text": text }) : string - !! 11871..11928: unresolved name: GPT4 + !! 11884..11941: unresolved name: GPT4 } function user.LlmReversedOrder$render_prompt(text: string) -> baml.llm.PromptAst throws never { baml.llm.render_prompt(GPT4, "LlmReversedOrder", map { "text": text }) : baml.llm.PromptAst - !! 11758..11928: unresolved name: GPT4 + !! 11771..11941: unresolved name: GPT4 } function user.LlmReversedOrder$build_request(text: string) -> baml.http.Request throws never { baml.llm.build_request(GPT4, "LlmReversedOrder", map { "text": text }) : baml.http.Request - !! 11758..11928: unresolved name: GPT4 + !! 11771..11941: unresolved name: GPT4 } function user.LlmReversedOrder$parse(json: string) -> string throws never { baml.llm.parse("LlmReversedOrder", json) : string } function user.LlmWithLongSignature(first_context: string, second_context: string, third_context: string, user_query: string) -> string throws never { baml.llm.call_llm_function(GPT4, "LlmWithLongSignature", map { "first_context": first_context, "second_context": second_context, "third_context": third_context, "user_query": user_query }) : string - !! 12113..12264: unresolved name: GPT4 + !! 12126..12277: unresolved name: GPT4 } function user.LlmWithLongSignature$render_prompt(first_context: string, second_context: string, third_context: string, user_query: string) -> baml.llm.PromptAst throws never { baml.llm.render_prompt(GPT4, "LlmWithLongSignature", map { "first_context": first_context, "second_context": second_context, "third_context": third_context, "user_query": user_query }) : baml.llm.PromptAst - !! 11928..12264: unresolved name: GPT4 + !! 11941..12277: unresolved name: GPT4 } function user.LlmWithLongSignature$build_request(first_context: string, second_context: string, third_context: string, user_query: string) -> baml.http.Request throws never { baml.llm.build_request(GPT4, "LlmWithLongSignature", map { "first_context": first_context, "second_context": second_context, "third_context": third_context, "user_query": user_query }) : baml.http.Request - !! 11928..12264: unresolved name: GPT4 + !! 11941..12277: unresolved name: GPT4 } function user.LlmWithLongSignature$parse(json: string) -> string throws never { baml.llm.parse("LlmWithLongSignature", json) : string diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__05_diagnostics.snap index e82c522bb0..c05f129736 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__05_diagnostics.snap @@ -308,6 +308,16 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ + [type] Error: callback parameter `callback` has implicit effect polymorphism but is not directly invoked; call `callback()` directly to propagate throws, or annotate with explicit `throws never` to opt out + ╭─[ function_decls.baml:77:88 ] + │ + 77 │ function DeepParamTypes(nested: map>>, callback: (x: int) -> (y: int) -> (z: int) -> string throws never) -> string { + │ ────────────────────────────┬─────────────────────────── + │ ╰───────────────────────────── callback parameter `callback` has implicit effect polymorphism but is not directly invoked; call `callback()` directly to propagate throws, or annotate with explicit `throws never` to opt out + │ + │ Note: Error code: E0001 +────╯ + [type] Error: unresolved name: GPT4 ╭─[ function_decls.baml:457:2 ] │ diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__10_formatter__function_decls.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__10_formatter__function_decls.snap index 0782af7ea0..361e629765 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__10_formatter__function_decls.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__10_formatter__function_decls.snap @@ -1,402 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs --- -// Testing all the formatter cases for function declarations -// both expression functions and llm functions - -// ===== Minimal expression function ===== -function Minimal() -> int { - 1 -} - -// ===== Single param ===== -function SingleParam(x: int) -> int { - x -} - -// ===== Multiple params ===== -function MultipleParams(a: int, b: int, c: string) -> int { - a + b -} - -// ===== Many params wrapping ===== -function ManyParams( - first_parameter: int, - second_parameter: string, - third_parameter: bool, - fourth_parameter: float, -) -> string { - "result" -} - -// ===== Very long param list ===== -function VeryLongParamList( - extremely_long_parameter_name: int, - another_very_long_parameter_name: string, - yet_another_long_one: bool, -) -> string { - "result" -} - -// ===== Six params with long names and types ===== -function SixLongParams( - first_parameter_name: map, - second_parameter_name: (int | string)[], - third_parameter_name: map, - fourth_parameter_name: int, - fifth_parameter_name: string, - sixth_parameter_name: float, -) -> string { - "result" -} - -// ===== Long name alone exceeding limit ===== -function ThisIsAnExtremelyLongFunctionNameThatDefinitelyCausesWrappingAllByItselfWithNoParamsNeeded( -) -> int { - 1 -} - -// ===== Long name + params ===== -function LongNameWithManyParameters_ThisFunctionExceedsLimit( - first_argument: string, - second_argument: int, - third_argument: bool, -) -> string { - "result" -} - -// ===== Long name + long return type ===== -function LongNameWithLongReturnType_ExceedsTheLineLimit( - x: int, -) -> map> { - { "a": { "b": 1 } } -} - -// ===== Long name + long params + long return ===== -function EverythingIsLong_NameParamsReturnType( - first_long_param: map, - second_long_param: (int | string | bool)[], -) -> map { - { "a": 1 } -} - -// ===== Complex return type ===== -function ComplexReturn(x: int) -> int | string | null { - x -} - -// ===== Return type: nested generics exceeding limit ===== -function NestedGenericReturn( -) -> map>> { - { "a": { "b": { "c": 1 } } } -} - -// ===== Return type: deeply nested function type exceeding limit ===== -function DeepFnReturn( -) -> (x: int) -> (y: int) -> (z: int) -> (w: int) -> map { - 1 -} - -// ===== Deeply nested param types ===== -function DeepParamTypes( - nested: map>>, - callback: (x: int) -> (y: int) -> (z: int) -> string, -) -> string { - "done" -} - -// ===== Params with constrained types exceeding limit ===== -function ConstrainedParams(name: string, score: int @check({{ this > 0 }})) -> string { - "done" -} - -// ===== Trivia on function declaration (line comments) ===== -// 1 function keyword leading -function TriviaFunc( // 6 open paren trailing - // 7 param leading - x: int, // 11 type trailing - // 12 close paren leading -) -> int // 17 return type trailing -{ // 19 body open trailing - x -} // 20 body close trailing - -/////////////////// - -// ===== Trivia on function declaration (block comments) ===== -/* 1 function keyword leading */ -function TriviaFuncBlock( /* 6 open paren trailing */ - /* 7 param leading */ - x: /* 9 colon trailing *//* 10 type leading */int, /* 11 type trailing */ - /* 12 close paren leading */ -) -> /* 15 arrow trailing *//* 16 return type leading */int /* 17 return type trailing */ { /* 19 body open trailing */ - x -} /* 20 body close trailing */ - -/////////////////// - -// ===== Multiple param trivia (line comments) ===== -function MultiParamTrivia( - // 1 first param leading - a: int, // 6 comma trailing - // 7 second param leading - b: string, // 11 second type trailing -) -> int { - a -} - -// ===== Multiple param trivia (block comments) ===== -function MultiParamTriviaBlock( - /* 1 first param leading */ - a: /* 3 first colon trailing *//* 4 first type leading */int/* 5 first type trailing */, /* 6 comma trailing */ - /* 7 second param leading */ - b: /* 9 second colon trailing *//* 10 second type leading */string, /* 11 second type trailing */ -) -> int { - a -} - -// ===== Trivia on wrapping function with many params (line comments) ===== -// 1 leading -function TriviaWrapParams( // 5 trailing - // 6 leading - first_parameter: int, // 11 trailing - // 12 leading - second_parameter: string, // 17 trailing - // 18 leading - third_parameter: bool, // 23 trailing - // 24 leading - fourth_parameter: float, // 28 trailing -) -> string // 33 trailing -{ // 35 trailing - "result" -} // 36 trailing - -/////////////////// - -// ===== Trivia on wrapping function with many params (block comments) ===== -/* 1 leading */ -function TriviaWrapParamsBlock( /* 5 trailing */ - /* 6 leading */ - first_parameter: /* 8 trailing *//* 9 leading */int/* 10 trailing */, /* 11 trailing */ - /* 12 leading */ - second_parameter: /* 14 trailing *//* 15 leading */string/* 16 trailing */, /* 17 trailing */ - /* 18 leading */ - third_parameter: /* 20 trailing *//* 21 leading */bool/* 22 trailing */, /* 23 trailing */ - /* 24 leading */ - fourth_parameter: /* 26 trailing *//* 27 leading */float, /* 28 trailing */ -) -> /* 31 trailing *//* 32 leading */string /* 33 trailing */ { /* 35 trailing */ - "result" -} /* 36 trailing */ - -/////////////////// - -// ===== Trivia on wrapping return type (line comments) ===== -function TriviaWrapReturnType( - x: int, -) -> int // 2 first trailing - // 3 pipe leading - | string // 6 second trailing - // 7 pipe leading - | bool // 10 third trailing - // 11 pipe leading - | null // 14 fourth trailing -{ - x -} - -// ===== Trivia on wrapping return type (block comments) ===== -function TriviaWrapReturnTypeBlock( - x: int, -) -> /* 1 first member leading */int /* 2 first trailing */ - /* 3 pipe leading */ - |/* 4 pipe trailing *//* 5 second leading */string /* 6 second trailing */ - /* 7 pipe leading */ - |/* 8 pipe trailing *//* 9 third leading */bool /* 10 third trailing */ - /* 11 pipe leading */ - |/* 12 pipe trailing *//* 13 fourth leading */null /* 14 fourth trailing */ -{ - x -} - -/////////////////// - -// ===== Trivia on wrapping long-name + params + return (line comments) ===== -// 1 leading -function TriviaEverythingLong_NameAndParamsAndReturn( // 5 trailing - // 6 leading - first_long_param: map, // 11 trailing - // 12 leading - second_long_param: (int | string | bool)[], // 16 trailing -) -> map // 21 trailing -{ // 23 trailing - { "a": 1 } -} // 24 trailing - -/////////////////// - -// ===== Trivia on wrapping long-name + params + return (block comments) ===== -/* 1 leading */ -function TriviaEverythingLong_NameAndParamsAndReturnBlock( /* 5 trailing */ - /* 6 leading */ - first_long_param: /* 8 trailing *//* 9 leading */map/* 10 trailing */, /* 11 trailing */ - /* 12 leading */ - second_long_param: /* 14 trailing *//* 15 leading */(int | string | bool)[], /* 16 trailing */ -) -> /* 19 trailing *//* 20 leading */map /* 21 trailing */ { /* 23 trailing */ - { "a": 1 } -} /* 24 trailing */ - -/////////////////// - -// ===== Expression function with multiple statements ===== -function MultiStatement(a: int, b: int) -> int { - let x = a + b; - let y = x * 2; - if (y > 10) { - return y; - } - return x; -} - -// ===== Expression function with final expression ===== -function FinalExpr(a: int, b: int) -> int { - a + b -} - -// ===== Deeply nested if/else body ===== -function DeeplyNestedBody(a: int, b: int, c: int) -> int { - if (a > 0) { - if (b > 0) { - if (c > 0) { - let x = a + b + c; - if (x > 100) { - return x * 2; - } else { - return x; - } - } else { - return a + b; - } - } else { - if (c > 0) { - return a + c; - } else { - return a; - } - } - } else { - return 0; - } -} - -// ===== Many nested let bindings with if/match/block ===== -function ManyNestedExprs(items: int[]) -> int { - let a = items[0]; - let b = items[1]; - let c = if (a > b) { - a - } else { - b - }; - let d = match (c) { - 0 => 1, - _ => c * 2, - }; - let e = { - let inner = d + a; - inner * inner - }; - e + d + c + b + a -} - -// ===== Deeply nested loops with breaks/continues ===== -function NestedLoops(matrix: int[][]) -> int { - let total = 0; - for (let row in matrix) { - for (let val in row) { - if (val < 0) { - continue; - } - if (val > 1000) { - break; - } - let processed = match (val % 3) { - 0 => val * 2, - 1 => val + 1, - _ => val, - }; - total += processed; - } - ; - } - ; - return total; -} - -// ===== Long expression at deep nesting pushing past limit ===== -function DeepNestLongExpr(a: int, b: int, c: int, d: int) -> int { - if (a > 0) { - if (b > 0) { - if (c > 0) { - return a * b * c * d + a * b * c + a * b + a + b * c * d + b * c + b + c * d + c + d; - } else { - return 0; - } - } else { - return 0; - } - } else { - return 0; - } -} - -// ===== LLM function basic ===== -function LlmBasic(name: string) -> string { - client: GPT4 - prompt: #"Say hello to {{name}}"# -} - -// ===== LLM function with multi-line prompt ===== -function LlmMultiLine(text: string) -> string { - client: GPT4 - prompt: #" - Analyze the following text: - {{ text }} - Return a summary. - "# -} - -// ===== LLM function with string client ===== -function LlmStringClient(text: string) -> string { - client: "openai/gpt-4o" - prompt: #"Summarize: {{ text }}"# -} - -// ===== LLM function prompt before client (order test) ===== -function LlmReversedOrder(text: string) -> string { - client: GPT4 - prompt: #"Summarize: {{ text }}"# -} - -// ===== LLM function with wrapping signature ===== -function LlmWithLongSignature( - first_context: string, - second_context: string, - third_context: string, - user_query: string, -) -> string { - client: GPT4 - prompt: #" - Context: {{ first_context }} {{ second_context }} {{ third_context }} - Query: {{ user_query }} - "# -} - -// ===== Extraneous whitespace on function ===== -function ExtraWhitespace(a: int, b: int) -> int { - a + b -} - -// ===== Short function name ===== -function F() -> int { - 1 -} +=== STRONG AST ERROR === +Unexpected additional element at function_decls.baml:77:131 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__chained_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__chained_hof.snap index 20f293ced9..964496fb79 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__chained_hof.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__chained_hof.snap @@ -31,6 +31,24 @@ Word "f" LParen "(" RParen ")" RBrace "}" +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "callback" +Word "parameter" +Error "`" +Word "g" +Error "`" +Word "has" +Word "implicit" +Word "effect" +Word "polymorphism" +Word "but" +Word "is" +Word "not" +Word "directly" +Word "invoked" Function "function" Word "apply_outer" LParen "(" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap index a6ef6a6fab..969319610a 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__compose_hof.snap @@ -25,6 +25,77 @@ Word "propagating" Throws "throws" Word "from" Word "both" +Dot "." +Slash "/" +Slash "/" +Word "The" +Word "callbacks" +Word "f" +Word "and" +Word "g" +Word "are" +Word "captured" +Word "by" +Word "the" +Word "returned" +Word "lambda" +Word "rather" +Word "than" +Word "being" +Slash "/" +Slash "/" +Word "directly" +Word "invoked" +In "in" +Word "the" +Word "compose" +Word "body" +Comma "," +Word "so" +Word "they" +Word "fire" +Word "the" +Word "unused" +Word "effect" +Word "var" +Word "diagnostic" +Dot "." +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "callback" +Word "parameter" +Error "`" +Word "f" +Error "`" +Word "has" +Word "implicit" +Word "effect" +Word "polymorphism" +Word "but" +Word "is" +Word "not" +Word "directly" +Word "invoked" +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "callback" +Word "parameter" +Error "`" +Word "g" +Error "`" +Word "has" +Word "implicit" +Word "effect" +Word "polymorphism" +Word "but" +Word "is" +Word "not" +Word "directly" +Word "invoked" Function "function" Word "compose" Less "<" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap index 492b8efd86..4a37b0b28b 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__hof_throws.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 11105 --- Slash "/" Slash "/" @@ -352,6 +351,24 @@ Word "wrapper" Word "outward" Throws "throws" Dot "." +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "callback" +Word "parameter" +Error "`" +Word "f" +Error "`" +Word "has" +Word "implicit" +Word "effect" +Word "polymorphism" +Word "but" +Word "is" +Word "not" +Word "directly" +Word "invoked" Function "function" Word "ignore_callback" LParen "(" @@ -415,6 +432,24 @@ RParen ")" Error "`" Word "calls" Dot "." +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "callback" +Word "parameter" +Error "`" +Word "f" +Error "`" +Word "has" +Word "implicit" +Word "effect" +Word "polymorphism" +Word "but" +Word "is" +Word "not" +Word "directly" +Word "invoked" Function "function" Word "apply_aliased" LParen "(" @@ -472,6 +507,24 @@ Word "is" Word "never" Word "invoked" Dot "." +Slash "/" +Slash "/" +Word "ERROR" +Colon ":" +Word "callback" +Word "parameter" +Error "`" +Word "f" +Error "`" +Word "has" +Word "implicit" +Word "effect" +Word "polymorphism" +Word "but" +Word "is" +Word "not" +Word "directly" +Word "invoked" Function "function" Word "ignore_callback_but_throw" LParen "(" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap index 3f10e1fa86..dc1018b1c5 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__stored_callback_enforcement.snap @@ -575,6 +575,24 @@ Function "function" Throws "throws" Word "annotation" Dot "." +Slash "/" +Slash "/" +Word "Also" +Colon ":" +Word "callback" +Word "parameter" +Error "`" +Word "f" +Error "`" +Word "has" +Word "implicit" +Word "effect" +Word "polymorphism" +Word "but" +Word "is" +Word "not" +Word "directly" +Word "invoked" Function "function" Word "test_store_direct_callback_param" LParen "(" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index b165fcc16c..46455fe8c4 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 12476 --- === TIR2 === function user.test_array_map_pure() -> int[] throws never { @@ -85,6 +84,7 @@ function user.apply_outer(g: () -> int throws __throws_g) -> int throws __throws { : int apply_inner(g) : int } + !! 224..234: callback parameter `g` has implicit effect polymorphism but is not directly invoked; call `g()` directly to propagate throws, or annotate with explicit `throws never` to opt out } function user.test_chained_pure() -> int throws never { { : int @@ -179,6 +179,8 @@ function user.compose(f: (A) -> B throws __throws_f, g: (B) -> C throws g(f(a)) } } + !! 511..520: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + !! 526..535: callback parameter `g` has implicit effect polymorphism but is not directly invoked; call `g()` directly to propagate throws, or annotate with explicit `throws never` to opt out } lambda user.compose { } @@ -890,6 +892,7 @@ function user.ignore_callback(f: () -> int throws __throws_f) -> int throws __th { : 42 42 : 42 } + !! 1700..1710: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out } function user.test_ignore_callback_throwing() -> int throws never { { : int @@ -907,6 +910,7 @@ function user.apply_aliased(f: () -> int throws __throws_f) -> int throws __thro let g = f : () -> int throws __throws_f g() : int } + !! 2077..2087: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out } function user.test_apply_aliased_throwing() -> int throws never { { : int @@ -923,6 +927,7 @@ function user.ignore_callback_but_throw(f: () -> int throws __throws_f) -> int t { : int helper_with_body_throw() : int } + !! 2426..2436: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out } function user.test_ignore_callback_but_throw() -> int throws string { { : int @@ -1385,7 +1390,8 @@ function user.test_store_direct_callback_param(f: () -> int throws __throws_f) - let stored = f : () -> int throws __throws_f stored() : int } - !! 2789..2790: function whose escaping throws are throws from callback parameter `f` cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type + !! 2835..2845: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + !! 2882..2883: function whose escaping throws are throws from callback parameter `f` cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type } class user.StoredPureHandler$stream { run: unknown diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index 79ba8ab7ee..e3a5e8b997 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -2,6 +2,16 @@ source: crates/baml_tests/src/generated_tests.rs --- === COMPILER2 DIAGNOSTICS === + [type] Error: callback parameter `g` has implicit effect polymorphism but is not directly invoked; call `g()` directly to propagate throws, or annotate with explicit `throws never` to opt out + ╭─[ chained_hof.baml:6:24 ] + │ + 6 │ function apply_outer(g: () -> int) -> int { + │ ─────┬──── + │ ╰────── callback parameter `g` has implicit effect polymorphism but is not directly invoked; call `g()` directly to propagate throws, or annotate with explicit `throws never` to opt out + │ + │ Note: Error code: E0001 +───╯ + [type] Error: throws contract violation: `never` is missing throws from callback parameter `f` ╭─[ class_field_fn_throws.baml:23:66 ] │ @@ -12,6 +22,26 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ + [type] Error: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + ╭─[ compose_hof.baml:9:5 ] + │ + 9 │ f: (A) -> B, + │ ────┬──── + │ ╰────── callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + │ + │ Note: Error code: E0001 +───╯ + + [type] Error: callback parameter `g` has implicit effect polymorphism but is not directly invoked; call `g()` directly to propagate throws, or annotate with explicit `throws never` to opt out + ╭─[ compose_hof.baml:10:5 ] + │ + 10 │ g: (B) -> C + │ ────┬──── + │ ╰────── callback parameter `g` has implicit effect polymorphism but is not directly invoked; call `g()` directly to propagate throws, or annotate with explicit `throws never` to opt out + │ + │ Note: Error code: E0001 +────╯ + [type] Error: throws contract violation: `never` is missing throws from callback parameter `f` ╭─[ hof_throws.baml:58:57 ] │ @@ -22,6 +52,36 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ + [type] Error: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + ╭─[ hof_throws.baml:64:28 ] + │ + 64 │ function ignore_callback(f: () -> int) -> int { + │ ─────┬──── + │ ╰────── callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + ╭─[ hof_throws.baml:75:26 ] + │ + 75 │ function apply_aliased(f: () -> int) -> int { + │ ─────┬──── + │ ╰────── callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + ╭─[ hof_throws.baml:86:38 ] + │ + 86 │ function ignore_callback_but_throw(f: () -> int) -> int { + │ ─────┬──── + │ ╰────── callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + │ + │ Note: Error code: E0001 +────╯ + [type] Warning: extraneous throws declaration: string ╭─[ lambda_throws_explicit.baml:17:27 ] │ @@ -152,10 +212,20 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ + [type] Error: callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + ╭─[ stored_callback_enforcement.baml:94:45 ] + │ + 94 │ function test_store_direct_callback_param(f: () -> int) -> int { + │ ─────┬──── + │ ╰────── callback parameter `f` has implicit effect polymorphism but is not directly invoked; call `f()` directly to propagate throws, or annotate with explicit `throws never` to opt out + │ + │ Note: Error code: E0001 +────╯ + [type] Error: function whose escaping throws are throws from callback parameter `f` cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type - ╭─[ stored_callback_enforcement.baml:94:27 ] + ╭─[ stored_callback_enforcement.baml:95:27 ] │ - 94 │ let stored: () -> int = f + 95 │ let stored: () -> int = f │ ┬ │ ╰── function whose escaping throws are throws from callback parameter `f` cannot be stored in a position typed `throws never`; add an explicit `throws` annotation to the stored function type │ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__chained_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__chained_hof.snap index be9c6fdb29..df84edbdd0 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__chained_hof.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__chained_hof.snap @@ -7,6 +7,7 @@ function apply_inner(f: () -> int) -> int { f() } +// ERROR: callback parameter `g` has implicit effect polymorphism but is not directly invoked function apply_outer(g: () -> int) -> int { apply_inner(g) } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__compose_hof.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__compose_hof.snap index 775dfe8022..da6b7ac002 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__compose_hof.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__compose_hof.snap @@ -2,4 +2,4 @@ source: crates/baml_tests/src/generated_tests.rs --- === STRONG AST ERROR === -Expected token/node of kind PARAMETER_LIST, but found GENERIC_PARAM_LIST at compose_hof.baml:4:17 +Expected token/node of kind PARAMETER_LIST, but found GENERIC_PARAM_LIST at compose_hof.baml:8:17 diff --git a/baml_language/crates/baml_tests/snapshots/lambda_advanced/baml_tests__lambda_advanced__04_tir.snap b/baml_language/crates/baml_tests/snapshots/lambda_advanced/baml_tests__lambda_advanced__04_tir.snap index 6a3a4b39ec..6fe9169b78 100644 --- a/baml_language/crates/baml_tests/snapshots/lambda_advanced/baml_tests__lambda_advanced__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/lambda_advanced/baml_tests__lambda_advanced__04_tir.snap @@ -194,7 +194,7 @@ lambda user.test_map_entries { function user.test_map_values_transform() -> int[] throws never { { : int[] let m = map { "x": 10, "y": 20 } : map - m.map_values((v) -> { ... }) : int[] + m.map_values((v) -> { ... }) : int[] (v) -> { ... } : (v: int) -> int { v * 2 @@ -206,7 +206,7 @@ lambda user.test_map_values_transform { function user.test_map_keys_transform() -> string[] throws never { { : string[] let m = map { "hello": 1 } : map - m.map_keys((k) -> { ... }) : string[] + m.map_keys((k) -> { ... }) : string[] (k) -> { ... } : (k: string) -> string { k @@ -301,7 +301,7 @@ lambda user.test_map_type_change_inferred { function user.test_map_values_fully_inferred() -> string[] throws never { { : (string | "big" | "small")[] let m = map { "x": 10, "y": 20 } : map - m.map_values((v) -> { ... }) : (string | "big" | "small")[] + m.map_values((v) -> { ... }) : (string | "big" | "small")[] (v) -> { ... } : (v: int) -> "big" | "small" { if (v > 15) diff --git a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded.snap b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded.snap index f671718285..4ec55bebe4 100644 --- a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded.snap +++ b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded.snap @@ -106,7 +106,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { 7 return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { 0 load_var 1 (self) 1 load_field 0 (prefix) 2 load_const 0 ("") @@ -203,7 +203,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () 93 jump -66 (to 27) } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null throws unknown, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { 0 load_var 1 (self) 1 load_field 0 (prefix) 2 load_const 0 ("") @@ -593,7 +593,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | 96 jump -88 (to 8) } -function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { +function testing.run_test(body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { 0 load_var 1 1 make_cell 2 store_var 1 diff --git a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded_unoptimized.snap b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded_unoptimized.snap index a1f60fcd40..dbacea2edd 100644 --- a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded_unoptimized.snap +++ b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_expanded_unoptimized.snap @@ -110,7 +110,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { 9 return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { 0 load_var 1 (self) 1 load_field 0 (prefix) 2 load_const 0 ("") @@ -209,7 +209,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () 95 jump -68 (to 27) } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null throws unknown, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { 0 load_var 1 (self) 1 load_field 0 (prefix) 2 load_const 0 ("") @@ -609,7 +609,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | 100 jump -92 (to 8) } -function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { +function testing.run_test(body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { 0 load_var 1 1 make_cell 2 store_var 1 diff --git a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_textual.snap b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_textual.snap index c15cbc5622..80878f03c2 100644 --- a/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_textual.snap +++ b/baml_language/crates/baml_tests/tests/bytecode_format/snapshots/bytecode_format__bytecode_display_textual.snap @@ -140,7 +140,7 @@ function testing.TestCollector.new(prefix: string) -> testing.TestCollector { return } -function testing.TestCollector.register_test(self: null, name: string, body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { +function testing.TestCollector.register_test(self: null, name: string, body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> null { load_var self load_field .prefix load_const "" @@ -261,7 +261,7 @@ function testing.TestCollector.register_test(self: null, name: string, body: () jump L3 } -function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { +function testing.TestCollector.register_test_set(self: null, name: string, collector: (testing.TestCollector) -> null throws unknown, runner: ((() -> testing.TestSetReport throws unknown) -> () -> testing.TestSetReport throws unknown)?) -> null { load_var self load_field .prefix load_const "" @@ -737,7 +737,7 @@ function testing.TestRegistry.serialize(self: null) -> (testing.SerializedTest | jump L0 } -function testing.run_test(body: () -> null, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { +function testing.run_test(body: () -> null throws unknown, runner: ((() -> testing.TestReport throws unknown) -> () -> testing.TestReport throws unknown)?) -> testing.TestReport { load_var ?1 make_cell store_var ?1 From aec818c09d474368b2f8b19348401dcedab0e957 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 03:45:43 -0500 Subject: [PATCH 16/26] Align TIR display with actual throws behavior Filter synthetic_display_vars in the TIR display to only show effect vars whose source callback params are directly invoked. This ensures the displayed throws type matches runtime behavior. - Make directly_invoked_callback_params public in callable_boundary.rs - Filter synthetic display vars in Branch 2 of throws display logic - Update TIR snapshots to reflect corrected display --- .../src/callable_boundary.rs | 2 +- ...l_tests__function_type_throws__04_tir.snap | 12 ++-- .../baml_tests/src/compiler2_tir/mod.rs | 64 +++++++++++++++---- 3 files changed, 59 insertions(+), 19 deletions(-) diff --git a/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs b/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs index 1a5e9a153c..f83b82a7bf 100644 --- a/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs +++ b/baml_language/crates/baml_compiler2_tir/src/callable_boundary.rs @@ -153,7 +153,7 @@ pub(crate) fn lower_callable_boundary<'db>( /// so the wrapper's outward throws will not include the callback's throws in /// that case. A future extension could follow simple single-assignment `let` /// aliases. -pub(crate) fn directly_invoked_callback_params(body: &ExprBody) -> FxHashSet { +pub fn directly_invoked_callback_params(body: &ExprBody) -> FxHashSet { let mut invoked = FxHashSet::default(); for (_, expr) in body.exprs.iter() { let callee = match expr { diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 46455fe8c4..6319a722cd 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -80,7 +80,7 @@ function user.apply_inner(f: () -> int throws __throws_f) -> int throws __throws f() : int } } -function user.apply_outer(g: () -> int throws __throws_g) -> int throws __throws_g { +function user.apply_outer(g: () -> int throws __throws_g) -> int throws never { { : int apply_inner(g) : int } @@ -171,7 +171,7 @@ class user.MixedHandler$stream { class user.MethodRunner$stream { value: null | unknown } -function user.compose(f: (A) -> B throws __throws_f, g: (B) -> C throws __throws_g) -> (A) -> C throws __throws_f | __throws_g { +function user.compose(f: (A) -> B throws __throws_f, g: (B) -> C throws __throws_g) -> (A) -> C throws never { { : never return : (a: A) -> C throws __throws_f | __throws_g (a: A) -> C { ... } : (a: A) -> C throws __throws_f | __throws_g @@ -888,7 +888,7 @@ function user.apply_underdeclared(f: () -> int throws __throws_f) -> int throws } !! 1490..1496: throws contract violation: `never` is missing throws from callback parameter `f` } -function user.ignore_callback(f: () -> int throws __throws_f) -> int throws __throws_f { +function user.ignore_callback(f: () -> int throws __throws_f) -> int throws never { { : 42 42 : 42 } @@ -905,7 +905,7 @@ function user.test_ignore_callback_throwing() -> int throws never { } lambda user.test_ignore_callback_throwing { } -function user.apply_aliased(f: () -> int throws __throws_f) -> int throws __throws_f { +function user.apply_aliased(f: () -> int throws __throws_f) -> int throws never { { : int let g = f : () -> int throws __throws_f g() : int @@ -923,7 +923,7 @@ function user.test_apply_aliased_throwing() -> int throws never { } lambda user.test_apply_aliased_throwing { } -function user.ignore_callback_but_throw(f: () -> int throws __throws_f) -> int throws __throws_f | string { +function user.ignore_callback_but_throw(f: () -> int throws __throws_f) -> int throws string { { : int helper_with_body_throw() : int } @@ -1385,7 +1385,7 @@ function user.make_stored_closure_with_throws() -> () -> int throws string throw } lambda user.make_stored_closure_with_throws { } -function user.test_store_direct_callback_param(f: () -> int throws __throws_f) -> int throws __throws_f { +function user.test_store_direct_callback_param(f: () -> int throws __throws_f) -> int throws never { { : int let stored = f : () -> int throws __throws_f stored() : int diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs index 00fefc9309..0e379f1211 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs @@ -35,6 +35,7 @@ pub(crate) mod support { signature::function_signature, }; use baml_compiler2_tir::{ + callable_boundary::directly_invoked_callback_params, inference::{ ScopeInference, infer_scope_types, render_scope_diagnostics, resolve_class_fields, resolve_type_alias, @@ -1110,6 +1111,16 @@ pub(crate) mod support { } }; + // Extract expr_body for use in throws display filtering. + let display_expr_body = func_body_opt.as_ref().and_then(|fb| { + if let baml_compiler2_hir::body::FunctionBody::Expr(body) = fb.as_ref() + { + Some(body) + } else { + None + } + }); + let throws = if let Some(t) = &sig.throws { let mut diags = Vec::new(); let declared = @@ -1121,19 +1132,48 @@ pub(crate) mod support { None => format!(" throws {declared}"), } } else if !synthetic_display_vars.is_empty() { - let body_facts = post_inference_throws - .as_ref() - .map(flatten_ty_to_facts) - .unwrap_or_default(); - let throws_ty = combine_effect_vars_with_body_throws( - &synthetic_display_vars, - body_facts, - ); - match &inferred_throws { - Some(inferred) => { - format!(" throws {throws_ty} infers {inferred}") + // Filter to only effect vars whose source callback param + // is actually directly invoked in the body. + let used_display_vars: Vec = + if let Some(body) = display_expr_body.as_ref() { + let invoked = directly_invoked_callback_params(body); + synthetic_display_vars + .iter() + .filter(|var| { + // Effect var name is "__throws_", extract param name. + let param = var + .as_str() + .strip_prefix("__throws_") + .unwrap_or(var.as_str()); + invoked.iter().any(|n| n.as_str() == param) + }) + .cloned() + .collect() + } else { + synthetic_display_vars.clone() + }; + + if used_display_vars.is_empty() { + // All effect vars are unused — fall through to Branch 3 logic. + match post_inference_throws.as_ref().or(inferred_throws.as_ref()) { + Some(inferred) => format!(" throws {inferred}"), + None => " throws never".to_string(), + } + } else { + let body_facts = post_inference_throws + .as_ref() + .map(flatten_ty_to_facts) + .unwrap_or_default(); + let throws_ty = combine_effect_vars_with_body_throws( + &used_display_vars, + body_facts, + ); + match &inferred_throws { + Some(inferred) => { + format!(" throws {throws_ty} infers {inferred}") + } + None => format!(" throws {throws_ty}"), } - None => format!(" throws {throws_ty}"), } } else { // No explicit throws, no effect vars. From 58cbcbe723848f7f572d47372e05238dc3a8b88f Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 03:53:03 -0500 Subject: [PATCH 17/26] Add regression tests for generic stored callbacks and heterogeneous collections Add targeted regression tests for two scenarios supported by the machinery but not directly exercised: - Generic stored callback fields (Handler pattern) - Heterogeneous stored callback collection widening (Task[]) Note: These tests currently show throws never because generic type parameters in stored callback fields are not yet instantiated during throws inference. The tests serve as regression coverage for when this is improved. --- .../generic_stored_callback.baml | 32 +++ .../heterogeneous_collection.baml | 25 ++ ...ws__01_lexer__generic_stored_callback.snap | 179 ++++++++++++++ ...s__01_lexer__heterogeneous_collection.snap | 148 ++++++++++++ ...s__02_parser__generic_stored_callback.snap | 220 +++++++++++++++++ ...__02_parser__heterogeneous_collection.snap | 195 +++++++++++++++ ...l_tests__function_type_throws__03_hir.snap | 35 ++- ...tests__function_type_throws__04_5_mir.snap | 227 ++++++++++++++++++ ...l_tests__function_type_throws__04_tir.snap | 71 ++++++ ...sts__function_type_throws__06_codegen.snap | 91 +++++++ ...10_formatter__generic_stored_callback.snap | 5 + ...0_formatter__heterogeneous_collection.snap | 5 + 12 files changed, 1232 insertions(+), 1 deletion(-) create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/generic_stored_callback.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/heterogeneous_collection.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_stored_callback.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__heterogeneous_collection.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_stored_callback.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__heterogeneous_collection.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_stored_callback.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__heterogeneous_collection.snap diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/generic_stored_callback.baml b/baml_language/crates/baml_tests/projects/function_type_throws/generic_stored_callback.baml new file mode 100644 index 0000000000..cfd35b4bf5 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/generic_stored_callback.baml @@ -0,0 +1,32 @@ +// === Generic stored callback field === +// The generic parameter E should flow through to the stored callback's throws. + +class Handler { + run: () -> null throws E +} + +function use_string_handler(h: Handler) -> null { + h.run() +} + +function use_int_handler(h: Handler) -> null { + h.run() +} + +function make_string_handler() -> Handler { + Handler { run: () -> null { throw "error" } } +} + +function make_int_handler() -> Handler { + Handler { run: () -> null { throw 42 } } +} + +// Calling use_string_handler should have: throws string +function test_string_handler() -> null { + use_string_handler(make_string_handler()) +} + +// Calling use_int_handler should have: throws int +function test_int_handler() -> null { + use_int_handler(make_int_handler()) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/heterogeneous_collection.baml b/baml_language/crates/baml_tests/projects/function_type_throws/heterogeneous_collection.baml new file mode 100644 index 0000000000..b4abea0bae --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/heterogeneous_collection.baml @@ -0,0 +1,25 @@ +// === Heterogeneous stored callback collection === +// Array literal should widen element types via union. + +class Task { + name: string + run: () -> null throws E +} + +function run_tasks(tasks: Task[]) -> null { + for (let task in tasks) { + task.run() + } + null +} + +function make_mixed_tasks() -> Task[] { + [ + Task { name: "a", run: () -> null { throw "error" } }, + Task { name: "b", run: () -> null { throw 42 } }, + ] +} + +function test_run_mixed_tasks() -> null { + run_tasks(make_mixed_tasks()) +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_stored_callback.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_stored_callback.snap new file mode 100644 index 0000000000..0632d20855 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_stored_callback.snap @@ -0,0 +1,179 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Generic" +Word "stored" +Word "callback" +Word "field" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "The" +Word "generic" +Word "parameter" +Word "E" +Word "should" +Word "flow" +Word "through" +Word "to" +Word "the" +Word "stored" +Word "callback" +Error "'" +Word "s" +Throws "throws" +Dot "." +Class "class" +Word "Handler" +Less "<" +Word "E" +Greater ">" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "E" +RBrace "}" +Function "function" +Word "use_string_handler" +LParen "(" +Word "h" +Colon ":" +Word "Handler" +Less "<" +Word "string" +Greater ">" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "h" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "use_int_handler" +LParen "(" +Word "h" +Colon ":" +Word "Handler" +Less "<" +Word "int" +Greater ">" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "h" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "make_string_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "Handler" +Less "<" +Word "string" +Greater ">" +LBrace "{" +Word "Handler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +RBrace "}" +Function "function" +Word "make_int_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "Handler" +Less "<" +Word "int" +Greater ">" +LBrace "{" +Word "Handler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +RBrace "}" +RBrace "}" +Slash "/" +Slash "/" +Word "Calling" +Word "use_string_handler" +Word "should" +Word "have" +Colon ":" +Throws "throws" +Word "string" +Function "function" +Word "test_string_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "use_string_handler" +LParen "(" +Word "make_string_handler" +LParen "(" +RParen ")" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "Calling" +Word "use_int_handler" +Word "should" +Word "have" +Colon ":" +Throws "throws" +Word "int" +Function "function" +Word "test_int_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "use_int_handler" +LParen "(" +Word "make_int_handler" +LParen "(" +RParen ")" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__heterogeneous_collection.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__heterogeneous_collection.snap new file mode 100644 index 0000000000..0e978ee3ac --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__heterogeneous_collection.snap @@ -0,0 +1,148 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Heterogeneous" +Word "stored" +Word "callback" +Word "collection" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Array" +Word "literal" +Word "should" +Word "widen" +Word "element" +Word "types" +Word "via" +Word "union" +Dot "." +Class "class" +Word "Task" +Less "<" +Word "E" +Greater ">" +LBrace "{" +Word "name" +Colon ":" +Word "string" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "E" +RBrace "}" +Function "function" +Word "run_tasks" +LParen "(" +Word "tasks" +Colon ":" +Word "Task" +Less "<" +Word "string" +Pipe "|" +Word "int" +Greater ">" +LBracket "[" +RBracket "]" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +For "for" +LParen "(" +Let "let" +Word "task" +In "in" +Word "tasks" +RParen ")" +LBrace "{" +Word "task" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Word "null" +RBrace "}" +Function "function" +Word "make_mixed_tasks" +LParen "(" +RParen ")" +Arrow "->" +Word "Task" +Less "<" +Word "string" +Pipe "|" +Word "int" +Greater ">" +LBracket "[" +RBracket "]" +LBrace "{" +LBracket "[" +Word "Task" +LBrace "{" +Word "name" +Colon ":" +Quote "\"" +Word "a" +Quote "\"" +Comma "," +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +Comma "," +Word "Task" +LBrace "{" +Word "name" +Colon ":" +Quote "\"" +Word "b" +Quote "\"" +Comma "," +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +RBrace "}" +Comma "," +RBracket "]" +RBrace "}" +Function "function" +Word "test_run_mixed_tasks" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "run_tasks" +LParen "(" +Word "make_mixed_tasks" +LParen "(" +RParen ")" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_stored_callback.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_stored_callback.snap new file mode 100644 index 0000000000..8ba7ec39ed --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_stored_callback.snap @@ -0,0 +1,220 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + CLASS_DEF + KW_CLASS "class" + WORD "Handler" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "E" + WORD "E" + GREATER ">" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "E" + WORD "E" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "use_string_handler" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "h" + COLON ":" + TYPE_EXPR + WORD "Handler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + PATH_EXPR "h.run" + WORD "h" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "use_int_handler" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "h" + COLON ":" + TYPE_EXPR + WORD "Handler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "int" + WORD "int" + GREATER ">" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + PATH_EXPR "h.run" + WORD "h" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_string_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + WORD "Handler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "Handler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_int_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + WORD "Handler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "int" + WORD "int" + GREATER ">" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "Handler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_string_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "use_string_handler" + CALL_ARGS + L_PAREN "(" + CALL_EXPR + WORD "make_string_handler" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_int_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "use_int_handler" + CALL_ARGS + L_PAREN "(" + CALL_EXPR + WORD "make_int_handler" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__heterogeneous_collection.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__heterogeneous_collection.snap new file mode 100644 index 0000000000..de7c62bfdd --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__heterogeneous_collection.snap @@ -0,0 +1,195 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + CLASS_DEF + KW_CLASS "class" + WORD "Task" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "E" + WORD "E" + GREATER ">" + L_BRACE "{" + FIELD + WORD "name" + COLON ":" + TYPE_EXPR "string" + WORD "string" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "E" + WORD "E" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "run_tasks" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "tasks" + COLON ":" + TYPE_EXPR + WORD "Task" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string | int" + WORD "string" + PIPE "|" + WORD "int" + GREATER ">" + L_BRACKET "[" + R_BRACKET "]" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + FOR_EXPR + KW_FOR "for" + L_PAREN "(" + LET_STMT "let task" + KW_LET "let" + WORD "task" + KW_IN "in" + WORD "tasks" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + PATH_EXPR "task.run" + WORD "task" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + WORD "null" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_mixed_tasks" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + WORD "Task" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string | int" + WORD "string" + PIPE "|" + WORD "int" + GREATER ">" + L_BRACKET "[" + R_BRACKET "]" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + ARRAY_LITERAL + L_BRACKET "[" + OBJECT_LITERAL + WORD "Task" + L_BRACE "{" + OBJECT_FIELD + WORD "name" + COLON ":" + STRING_LITERAL "a" + QUOTE """ + WORD "a" + QUOTE """ + COMMA "," + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + COMMA "," + OBJECT_LITERAL + WORD "Task" + L_BRACE "{" + OBJECT_FIELD + WORD "name" + COLON ":" + STRING_LITERAL "b" + QUOTE """ + WORD "b" + QUOTE """ + COMMA "," + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + R_BRACE "}" + COMMA "," + R_BRACKET "]" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_run_mixed_tasks" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_tasks" + CALL_ARGS + L_PAREN "(" + CALL_EXPR + WORD "make_mixed_tasks" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index ff052a5853..6e4a42dd72 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 12292 --- === HIR2 === function user.test_array_map_pure() -> int[] [expr] { @@ -225,6 +224,40 @@ type user.Wrapper = (() -> int) -> int function user.use_alias_pure(f: user.AliasPure) -> int [expr] { { } f() } +class user.Handler { + run: () -> null +} +function user.make_int_handler() -> user.Handler [expr] { + { } user.Handler { run: () -> null { { throw 42 } } } +} +function user.make_string_handler() -> user.Handler [expr] { + { } user.Handler { run: () -> null { { throw "error" } } } +} +function user.test_int_handler() -> null [expr] { + { } use_int_handler(make_int_handler()) +} +function user.test_string_handler() -> null [expr] { + { } use_string_handler(make_string_handler()) +} +function user.use_int_handler(h: user.Handler) -> null [expr] { + { } h.run() +} +function user.use_string_handler(h: user.Handler) -> null [expr] { + { } h.run() +} +class user.Task { + name: string + run: () -> null +} +function user.make_mixed_tasks() -> user.Task[] [expr] { + { } [user.Task { name: "a", run: () -> null { { throw "error" } } }, user.Task { name: "b", run: () -> null { { throw 42 } } }] +} +function user.run_tasks(tasks: user.Task[]) -> null [expr] { + { for task in tasks { } task.run() } null +} +function user.test_run_mixed_tasks() -> null [expr] { + { } run_tasks(make_mixed_tasks()) +} function user.apply_guarded(f: () -> int) -> int [expr] { { let result = f(); if (result Lt 0) { throw "negative result" } } result } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index f8da30a8b9..92fc402c2f 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -1547,6 +1547,233 @@ fn user.use_alias_pure(f: () -> int) -> int { } } +fn user.make_int_handler() -> Handler { + // Locals: + let _0: Handler // _0 // return + let _1: () -> void throws int + + bb0: { + _1 = make_closure lambda[0](); + _0 = Handler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + +fn user.make_string_handler() -> Handler { + // Locals: + let _0: Handler // _0 // return + let _1: () -> void throws string + + bb0: { + _1 = make_closure lambda[0](); + _0 = Handler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + +fn user.test_int_handler() -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler + + bb0: { + _1 = call const fn user.make_int_handler() -> [bb1]; + } + + bb1: { + _0 = call const fn user.use_int_handler(copy _1) -> [bb2]; + } + + bb2: { + return; + } +} + +fn user.test_string_handler() -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler + + bb0: { + _1 = call const fn user.make_string_handler() -> [bb1]; + } + + bb1: { + _0 = call const fn user.use_string_handler(copy _1) -> [bb2]; + } + + bb2: { + return; + } +} + +fn user.use_int_handler(h: Handler) -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler // h // param + let _2: () -> null + + bb0: { + _2 = copy _1.0; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.use_string_handler(h: Handler) -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler // h // param + let _2: () -> null + + bb0: { + _2 = copy _1.0; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.make_mixed_tasks() -> Task[] { + // Locals: + let _0: Task[] // _0 // return + let _1: Task + let _2: () -> void throws string + let _3: Task + let _4: () -> void throws int + + bb0: { + _2 = make_closure lambda[0](); + _1 = Task { const "a", copy _2 }; + _4 = make_closure lambda[1](); + _3 = Task { const "b", copy _4 }; + _0 = [copy _1, copy _3]; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + +// lambda[1] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + +fn user.run_tasks(tasks: Task[]) -> null { + // Locals: + let _0: null // _0 // return + let _1: Task[] // tasks // param + let _2: int // __for_idx + let _3: int + let _4: bool + let _5: Task + let _6: Task // task + let _7: void + let _8: () -> null + let _9: Task + + bb0: { + _2 = const 0_i64; + goto -> bb1; + } + + bb1: { + _3 = len(_1); + _4 = copy _2 < copy _3; + branch copy _4 -> [bb4, bb2]; + } + + bb2: { + _0 = const null; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + _5 = copy _1[_2]; + _6 = copy _5; + _9 = copy _6; + _8 = copy _9.1; + _7 = call copy _8() -> [bb5]; + } + + bb5: { + _2 = copy _2 + const 1_i64; + goto -> bb1; + } +} + +fn user.test_run_mixed_tasks() -> null { + // Locals: + let _0: null // _0 // return + let _1: Task[] + + bb0: { + _1 = call const fn user.make_mixed_tasks() -> [bb1]; + } + + bb1: { + _0 = call const fn user.run_tasks(copy _1) -> [bb2]; + } + + bb2: { + return; + } +} + fn user.apply_guarded(f: () -> int) -> int { // Locals: let _0: int // _0 // return diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 6319a722cd..e8a9cd9a12 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -638,6 +638,77 @@ type user.NoThrow$stream = never type user.AliasPure$stream = unknown type user.Mapper$stream = unknown type user.Wrapper$stream = unknown +class user.Handler { + run: () -> null throws E +} +function user.use_string_handler(h: user.Handler) -> null throws never { + { : null + h.run() : null + } +} +function user.use_int_handler(h: user.Handler) -> null throws never { + { : null + h.run() : null + } +} +function user.make_string_handler() -> user.Handler throws never { + { : user.Handler + Handler { run: () -> null { ... } } : user.Handler + } +} +lambda user.make_string_handler { +} +function user.make_int_handler() -> user.Handler throws never { + { : user.Handler + Handler { run: () -> null { ... } } : user.Handler + } +} +lambda user.make_int_handler { +} +function user.test_string_handler() -> null throws never { + { : null + use_string_handler(make_string_handler()) : null + } +} +function user.test_int_handler() -> null throws never { + { : null + use_int_handler(make_int_handler()) : null + } +} +class user.Handler$stream { + run: unknown +} +class user.Task { + name: string + run: () -> null throws E +} +function user.run_tasks(tasks: user.Task[]) -> null throws never { + { : null + for task in tasks + { : null + task.run() : null + } + null : null + } +} +function user.make_mixed_tasks() -> user.Task[] throws never { + { : user.Task[] + [Task { name: "a", run: () -> null { ... } }, Task { name: "b", run: () -> null { ... } }] : user.Task[] + } +} +lambda user.make_mixed_tasks { +} +lambda user.make_mixed_tasks { +} +function user.test_run_mixed_tasks() -> null throws never { + { : null + run_tasks(make_mixed_tasks()) : null + } +} +class user.Task$stream { + name: null | string + run: unknown +} function user.apply_guarded(f: () -> int throws __throws_f) -> int throws __throws_f | string { { : int let result = f() : int diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 505bc2bcfc..2462701429 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -347,6 +347,28 @@ function user.make_good_stored_handler() -> StoredPureHandler { return } +function user.make_int_handler() -> Handler { + alloc_instance Handler + make_closure ., 0 + init_field .run + return +} + +function user.make_mixed_tasks() -> Task[] { + alloc_instance Task + load_const "a" + init_field .name + make_closure ., 0 + init_field .run + alloc_instance Task + load_const "b" + init_field .name + make_closure ., 0 + init_field .run + alloc_array 2 + return +} + function user.make_pure() -> () -> int { make_closure ., 0 return @@ -381,6 +403,13 @@ function user.make_stored_throwing_handler() -> StoredThrowingHandler { return } +function user.make_string_handler() -> Handler { + alloc_instance Handler + make_closure ., 0 + init_field .run + return +} + function user.make_thrower() -> () -> int throws string { make_closure ., 0 return @@ -448,6 +477,36 @@ function user.run_pure(f: () -> int) -> int { return } +function user.run_tasks(tasks: Task[]) -> null { + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var tasks + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_const null + return + + L2: + load_var tasks + load_var __for_idx + load_array_element + load_field .run + call_indirect + pop 1 + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 +} + function user.run_throwing(f: () -> int) -> int { load_var f call_indirect @@ -707,6 +766,12 @@ function user.test_ignore_callback_throwing() -> int { return } +function user.test_int_handler() -> null { + call user.make_int_handler + call user.use_int_handler + return +} + function user.test_lambda_throws_class() -> int { make_closure ., 0 call_indirect @@ -825,6 +890,12 @@ function user.test_rethrows_enum() -> int { return } +function user.test_run_mixed_tasks() -> null { + call user.make_mixed_tasks + call user.run_tasks + return +} + function user.test_run_pure() -> int { make_closure ., 0 call user.run_pure @@ -873,6 +944,12 @@ function user.test_stored_local_with_throws() -> null { return } +function user.test_string_handler() -> null { + call user.make_string_handler + call user.use_string_handler + return +} + function user.test_throwing_int() -> string { make_closure ., 0 call_indirect @@ -1182,8 +1259,22 @@ function user.use_enum_thrower(f: () -> int throws ErrorKind) -> int { return } +function user.use_int_handler(h: Handler) -> null { + load_var h + load_field .run + call_indirect + return +} + function user.use_mixed_thrower(f: () -> int throws ErrorKind | ApiError | string) -> int { load_var f call_indirect return } + +function user.use_string_handler(h: Handler) -> null { + load_var h + load_field .run + call_indirect + return +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_stored_callback.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_stored_callback.snap new file mode 100644 index 0000000000..a210b46ea0 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_stored_callback.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +An element at generic_stored_callback.baml:4:14 was a token when it should have been a node. diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__heterogeneous_collection.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__heterogeneous_collection.snap new file mode 100644 index 0000000000..f504b6b85b --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__heterogeneous_collection.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +An element at heterogeneous_collection.baml:4:11 was a token when it should have been a node. From 5333f74cab7a1820d0a8ca96272fe19cabcd0a3f Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 07:52:00 -0500 Subject: [PATCH 18/26] Add generic nominal instantiation soundness and cross-package resolution Phase 1: StructuralTy::Class and StructuralTy::Enum now carry type_args, enabling invariant subtype comparison for instantiated nominals (e.g. Handler vs Handler are distinct). Shared instantiate_alias helper funnels all six alias expansion sites. Generic inference helpers (contains_typevar, infer_bindings, erase_unresolved_typevars) recurse into nominal type_args. Phase 2: Cross-package class field and method resolution routed through PackageResolutionContext with generic substitution applied inside the methods. Builder delegates to res_ctx instead of bypassing PackageInterface. Fixed cross-namespace struct literal resolution fallback. New test fixtures for generic subtype rejection and generic function inference over nominal wrappers. --- .../crates/baml_compiler2_ast/src/ast.rs | 2 + .../baml_compiler2_ast/src/companions.rs | 1 + .../crates/baml_compiler2_ast/src/lib.rs | 10 +- .../baml_compiler2_ast/src/lower_cst.rs | 8 + .../baml_compiler2_ast/src/lower_expr_body.rs | 17 + .../baml_compiler2_ast/src/lower_type_expr.rs | 28 +- .../crates/baml_compiler2_mir/src/lower.rs | 4 +- .../crates/baml_compiler2_ppir/src/expand.rs | 30 +- .../crates/baml_compiler2_ppir/src/ty.rs | 20 +- .../crates/baml_compiler2_tir/src/builder.rs | 424 +++++++++++++----- .../crates/baml_compiler2_tir/src/generics.rs | 136 +++++- .../baml_compiler2_tir/src/inference.rs | 2 +- .../baml_compiler2_tir/src/lower_type_expr.rs | 32 +- .../baml_compiler2_tir/src/normalize.rs | 135 ++++-- .../src/package_interface.rs | 142 +++++- .../baml_compiler2_tir/src/throw_inference.rs | 12 +- .../src/throws_semantics.rs | 11 +- .../crates/baml_compiler2_tir/src/ty.rs | 117 ++++- .../src/control_flow/from_ast.rs | 4 + .../baml_lsp2_actions/src/completions.rs | 6 +- .../generic_function_inference.baml | 18 + .../generic_subtype_rejection.baml | 15 + .../baml_tests____baml_std____04_5_mir.snap | 4 +- ...baml_tests____testing_std____04_5_mir.snap | 6 +- .../baml_tests__format_checks__04_tir.snap | 2 +- ..._01_lexer__generic_function_inference.snap | 101 +++++ ...__01_lexer__generic_subtype_rejection.snap | 89 ++++ ...02_parser__generic_function_inference.snap | 122 +++++ ..._02_parser__generic_subtype_rejection.snap | 81 ++++ ...l_tests__function_type_throws__03_hir.snap | 21 + ...tests__function_type_throws__04_5_mir.snap | 83 +++- ...l_tests__function_type_throws__04_tir.snap | 66 ++- ..._function_type_throws__05_diagnostics.snap | 47 ++ ...sts__function_type_throws__06_codegen.snap | 36 ++ ...formatter__generic_function_inference.snap | 5 + ..._formatter__generic_subtype_rejection.snap | 5 + ...s__test_with_runner_ambiguity__04_tir.snap | 1 + ...with_runner_ambiguity__05_diagnostics.snap | 11 + .../baml_tests__type_aliases__04_tir.snap | 2 +- ...l_tests__type_aliases__05_diagnostics.snap | 14 +- .../baml_tests/src/compiler2_tir/mod.rs | 3 +- .../baml_tests/src/compiler2_tir/phase5.rs | 6 + 42 files changed, 1644 insertions(+), 235 deletions(-) create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/generic_function_inference.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/generic_subtype_rejection.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_function_inference.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_subtype_rejection.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_function_inference.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_subtype_rejection.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_function_inference.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_subtype_rejection.snap diff --git a/baml_language/crates/baml_compiler2_ast/src/ast.rs b/baml_language/crates/baml_compiler2_ast/src/ast.rs index 298ba428b2..ef55da7282 100644 --- a/baml_language/crates/baml_compiler2_ast/src/ast.rs +++ b/baml_language/crates/baml_compiler2_ast/src/ast.rs @@ -41,6 +41,7 @@ pub enum TypeExpr { /// Named type path: `User`, `baml.http.Request` Path { segments: Vec, + type_args: Vec, attrs: Vec, }, /// Primitive types @@ -446,6 +447,7 @@ pub enum Expr { }, Object { type_name: Option, + type_args: Vec, fields: Vec<(Name, ExprId)>, spreads: Vec, }, diff --git a/baml_language/crates/baml_compiler2_ast/src/companions.rs b/baml_language/crates/baml_compiler2_ast/src/companions.rs index ff6fab4849..785bc0355d 100644 --- a/baml_language/crates/baml_compiler2_ast/src/companions.rs +++ b/baml_language/crates/baml_compiler2_ast/src/companions.rs @@ -100,6 +100,7 @@ fn make_llm_companion( let return_type = SpannedTypeExpr { expr: TypeExpr::Path { segments: return_type_path.iter().map(Name::new).collect(), + type_args: vec![], attrs: vec![], }, span: parent.span, diff --git a/baml_language/crates/baml_compiler2_ast/src/lib.rs b/baml_language/crates/baml_compiler2_ast/src/lib.rs index 62872e5753..c5265e000e 100644 --- a/baml_language/crates/baml_compiler2_ast/src/lib.rs +++ b/baml_language/crates/baml_compiler2_ast/src/lib.rs @@ -65,6 +65,7 @@ mod tests { (Path($name:expr $(, Attr($a:expr))*)) => { TypeExpr::Path { segments: vec![baml_base::Name::new($name)], + type_args: vec![], attrs: type_expr!(@attrs $(, Attr($a))*), } }; @@ -153,8 +154,13 @@ mod tests { TypeExpr::Rust { attrs } => TypeExpr::Rust { attrs: strip_attrs(attrs), }, - TypeExpr::Path { segments, attrs } => TypeExpr::Path { + TypeExpr::Path { + segments, + type_args, + attrs, + } => TypeExpr::Path { segments: segments.clone(), + type_args: type_args.iter().map(strip_spans).collect(), attrs: strip_attrs(attrs), }, TypeExpr::Optional { inner, attrs } => TypeExpr::Optional { @@ -292,6 +298,7 @@ class Response { baml_base::Name::new("errors"), baml_base::Name::new("Io"), ], + type_args: vec![], attrs: vec![] } ); @@ -972,6 +979,7 @@ retry_policy MyRetry { let (type_name, fields, _) = match root_expr { Expr::Object { type_name, + type_args: _, fields, spreads, } => (type_name, fields, spreads), diff --git a/baml_language/crates/baml_compiler2_ast/src/lower_cst.rs b/baml_language/crates/baml_compiler2_ast/src/lower_cst.rs index 1459b23ca4..1e4e13c5e2 100644 --- a/baml_language/crates/baml_compiler2_ast/src/lower_cst.rs +++ b/baml_language/crates/baml_compiler2_ast/src/lower_cst.rs @@ -419,6 +419,7 @@ pub(crate) fn synthesize_llm_builtin_call( let counter = alloc(Expr::Literal(Literal::Int(0))); alloc(Expr::Object { type_name: Some(Name::new("baml.llm.Client")), + type_args: vec![], fields: vec![ (Name::new("name"), name_lit), (Name::new("client_type"), ct_variant), @@ -953,6 +954,7 @@ fn synthesize_init_test_function( type_expr: Some(SpannedTypeExpr { expr: crate::ast::TypeExpr::Path { segments: vec![Name::new("testing"), Name::new("TestCollector")], + type_args: vec![], attrs: vec![], }, span, @@ -1053,6 +1055,7 @@ fn synthesize_register_call( type_expr: Some(SpannedTypeExpr { expr: crate::ast::TypeExpr::Path { segments: vec![Name::new("testing"), Name::new("TestCollector")], + type_args: vec![], attrs: vec![], }, span, @@ -1229,6 +1232,7 @@ fn synthesize_retry_policy_let( let root = alloc(Expr::Object { type_name: Some(Name::new("RetryPolicy")), + type_args: vec![], fields, spreads: vec![], }); @@ -1466,6 +1470,7 @@ fn synthesize_client_let( // Client { name, client_type, sub_clients, retry, counter } let root = alloc(Expr::Object { type_name: Some(Name::new("Client")), + type_args: vec![], fields: vec![ (Name::new("name"), name_expr), (Name::new("client_type"), client_type_expr), @@ -1677,6 +1682,7 @@ fn synthesize_client_new_companion( if !provider_fields_set.is_empty() { alloc(Expr::Object { type_name: Some(Name::new(type_name)), + type_args: vec![], fields: prov_fields, spreads: vec![], }) @@ -1716,6 +1722,7 @@ fn synthesize_client_new_companion( let options_expr = alloc(Expr::Object { type_name: Some(Name::new("baml.llm.PrimitiveClientOptions")), + type_args: vec![], fields: options_fields, spreads: vec![], }); @@ -1727,6 +1734,7 @@ fn synthesize_client_new_companion( ))); let root = alloc(Expr::Object { type_name: Some(Name::new("baml.llm.PrimitiveClient")), + type_args: vec![], fields: vec![ (Name::new("name"), name_lit), (Name::new("provider"), provider_lit), diff --git a/baml_language/crates/baml_compiler2_ast/src/lower_expr_body.rs b/baml_language/crates/baml_compiler2_ast/src/lower_expr_body.rs index f823671e7d..f7c8af9c3b 100644 --- a/baml_language/crates/baml_compiler2_ast/src/lower_expr_body.rs +++ b/baml_language/crates/baml_compiler2_ast/src/lower_expr_body.rs @@ -1088,6 +1088,7 @@ impl LoweringContext { name, ty: crate::ast::TypeExpr::Path { segments: vec![Name::new(&text)], + type_args: vec![], attrs: vec![], }, }; @@ -1990,6 +1991,7 @@ impl LoweringContext { let mut spreads = Vec::new(); let mut position = 0; let mut type_name = None; + let mut type_args = Vec::new(); // Look for the optional type name (first WORD or path before the brace). // The type name may be: @@ -2021,6 +2023,19 @@ impl LoweringContext { if let Some(name) = last_word { type_name = Some(name); } + type_args = child_node + .children() + .find(|n| n.kind() == SyntaxKind::GENERIC_ARGS) + .into_iter() + .flat_map(|args_node| { + args_node + .children() + .filter(|n| n.kind() == SyntaxKind::TYPE_EXPR) + .filter_map(baml_compiler_syntax::ast::TypeExpr::cast) + .map(|arg| crate::lower_type_expr::lower_type_expr_node(&arg)) + .collect::>() + }) + .collect(); // After handling the path node, stop scanning for more pre-brace items. break 'outer; } @@ -2086,6 +2101,7 @@ impl LoweringContext { self.alloc_expr( Expr::Object { type_name, + type_args, fields, spreads, }, @@ -2840,6 +2856,7 @@ impl LoweringContext { type_expr: Some(SpannedTypeExpr { expr: TypeExpr::Path { segments: vec![Name::new("testing"), Name::new("TestCollector")], + type_args: vec![], attrs: vec![], }, span, diff --git a/baml_language/crates/baml_compiler2_ast/src/lower_type_expr.rs b/baml_language/crates/baml_compiler2_ast/src/lower_type_expr.rs index 61bfbaf971..680254e51c 100644 --- a/baml_language/crates/baml_compiler2_ast/src/lower_type_expr.rs +++ b/baml_language/crates/baml_compiler2_ast/src/lower_type_expr.rs @@ -252,7 +252,13 @@ fn lower_base_type(type_expr: &CstTypeExpr) -> TypeExpr { } // Named type (primitive or user-defined) - return lower_from_type_name(&name); + let args = type_expr.type_arg_exprs(); + return lower_from_type_name( + &name, + args.iter() + .map(|arg| lower_type_expr_inner(arg, false)) + .collect(), + ); } TypeExpr::Unknown { attrs: vec![] } @@ -343,7 +349,21 @@ fn lower_union_member_base(parts: &baml_compiler_syntax::ast::UnionMemberParts) value: baml_base::Literal::Bool(false), attrs: vec![], }, - _ => lower_from_type_name(&name), + _ => lower_from_type_name( + &name, + parts + .type_args() + .into_iter() + .flat_map(|type_args_node| { + type_args_node + .children() + .filter(|n| n.kind() == baml_compiler_syntax::SyntaxKind::TYPE_EXPR) + .filter_map(baml_compiler_syntax::ast::TypeExpr::cast) + .map(|arg| lower_type_expr_inner(&arg, false)) + .collect::>() + }) + .collect(), + ), }; } @@ -351,7 +371,7 @@ fn lower_union_member_base(parts: &baml_compiler_syntax::ast::UnionMemberParts) } /// Create a `TypeExpr` from a type name string (primitive or user-defined). -fn lower_from_type_name(name: &str) -> TypeExpr { +fn lower_from_type_name(name: &str, type_args: Vec) -> TypeExpr { match name { "int" => TypeExpr::Int { attrs: vec![] }, "float" => TypeExpr::Float { attrs: vec![] }, @@ -384,11 +404,13 @@ fn lower_from_type_name(name: &str) -> TypeExpr { let segments: Vec = name.split('.').map(Name::new).collect(); TypeExpr::Path { segments, + type_args, attrs: vec![], } } else { TypeExpr::Path { segments: vec![Name::new(name)], + type_args, attrs: vec![], } } diff --git a/baml_language/crates/baml_compiler2_mir/src/lower.rs b/baml_language/crates/baml_compiler2_mir/src/lower.rs index 496148a921..149bfa5c6f 100644 --- a/baml_language/crates/baml_compiler2_mir/src/lower.rs +++ b/baml_language/crates/baml_compiler2_mir/src/lower.rs @@ -963,7 +963,8 @@ impl LoweringContext<'_> { let tir_ty = baml_compiler2_tir::ty::Ty::Class( baml_compiler2_tir::lower_type_expr::qualify_def( self.db, def, cn, - ), + ) + .into(), baml_compiler2_tir::ty::TyAttr::default(), ); self.resolved_aliases.convert(&tir_ty) @@ -1456,6 +1457,7 @@ impl LoweringContext<'_> { type_name, fields, spreads, + .. } => { self.lower_object(expr_id, type_name.as_ref(), &fields, &spreads, dest); } diff --git a/baml_language/crates/baml_compiler2_ppir/src/expand.rs b/baml_language/crates/baml_compiler2_ppir/src/expand.rs index 595fe5b2cf..6f5d04da46 100644 --- a/baml_language/crates/baml_compiler2_ppir/src/expand.rs +++ b/baml_language/crates/baml_compiler2_ppir/src/expand.rs @@ -140,13 +140,21 @@ fn requalify_for_caller(ty: PpirTy, alias_ns: &[Name], caller_ns: &[Name]) -> Pp return ty; } match ty { - PpirTy::Named { path, attrs } if path.len() == 1 && path[0].as_str() != "root" => { + PpirTy::Named { + path, + type_args, + attrs, + } if path.len() == 1 && path[0].as_str() != "root" => { let mut qualified = Vec::with_capacity(alias_ns.len() + 2); qualified.push(SmolStr::from("root")); qualified.extend(alias_ns.iter().cloned()); qualified.extend(path); PpirTy::Named { path: qualified, + type_args: type_args + .into_iter() + .map(|arg| requalify_for_caller(arg, alias_ns, caller_ns)) + .collect(), attrs, } } @@ -276,7 +284,9 @@ pub fn expand_partial(ty: &PpirTy, ctx: &ExpandCtx<'_>) -> PpirTy { | PpirTy::CannotBeStreamed { .. } => ty.clone_without_attrs(), // Named types — depends on classification - PpirTy::Named { path, .. } => { + PpirTy::Named { + path, type_args, .. + } => { // Already *$stream → unchanged if path.last().is_some_and(|n| n.as_str().ends_with("$stream")) { return ty.clone_without_attrs(); @@ -291,6 +301,10 @@ pub fn expand_partial(ty: &PpirTy, ctx: &ExpandCtx<'_>) -> PpirTy { .cloned() .chain(std::iter::once(SmolStr::new(format!("{bare_name}$stream")))) .collect(), + type_args: type_args + .iter() + .map(|arg| expand_partial(arg, ctx)) + .collect(), attrs: d, } } @@ -408,7 +422,9 @@ fn stream_expand_inner(ty: &PpirTy, ctx: &ExpandCtx<'_>, depth: u32) -> (PpirTy, ), // Named types - PpirTy::Named { path, .. } => { + PpirTy::Named { + path, type_args, .. + } => { // Already *$stream → treat like T$stream if path.last().is_some_and(|n| n.as_str().ends_with("$stream")) { ( @@ -439,6 +455,10 @@ fn stream_expand_inner(ty: &PpirTy, ctx: &ExpandCtx<'_>, depth: u32) -> (PpirTy, ( PpirTy::Named { path: stream_path, + type_args: type_args + .iter() + .map(|arg| expand_partial(arg, ctx)) + .collect(), attrs: d.clone(), }, DefaultWhenPending::PrependNull, @@ -490,6 +510,10 @@ fn stream_expand_inner(ty: &PpirTy, ctx: &ExpandCtx<'_>, depth: u32) -> (PpirTy, ( PpirTy::Named { path: stream_path, + type_args: type_args + .iter() + .map(|arg| expand_partial(arg, ctx)) + .collect(), attrs: d.clone(), }, DefaultWhenPending::PrependNull, diff --git a/baml_language/crates/baml_compiler2_ppir/src/ty.rs b/baml_language/crates/baml_compiler2_ppir/src/ty.rs index 264d3fa1c5..f7cfe70545 100644 --- a/baml_language/crates/baml_compiler2_ppir/src/ty.rs +++ b/baml_language/crates/baml_compiler2_ppir/src/ty.rs @@ -47,6 +47,7 @@ pub struct PpirTypeAttrs { pub enum PpirTy { Named { path: Vec, + type_args: Vec, attrs: PpirTypeAttrs, }, Int { @@ -139,8 +140,11 @@ impl PpirTy { pub fn clone_without_attrs(&self) -> Self { let d = PpirTypeAttrs::default(); match self { - Self::Named { path, .. } => Self::Named { + Self::Named { + path, type_args, .. + } => Self::Named { path: path.clone(), + type_args: type_args.clone(), attrs: d, }, Self::Int { .. } => Self::Int { attrs: d }, @@ -182,6 +186,7 @@ impl PpirTy { pub fn named(name: impl Into) -> Self { PpirTy::Named { path: vec![name.into()], + type_args: vec![], attrs: PpirTypeAttrs::default(), } } @@ -239,8 +244,13 @@ impl PpirTy { TypeExpr::Bool { .. } => PpirTy::Bool { attrs }, TypeExpr::Null { .. } => PpirTy::Null { attrs }, TypeExpr::Never { .. } => PpirTy::Never { attrs }, - TypeExpr::Path { segments, .. } => PpirTy::Named { + TypeExpr::Path { + segments, + type_args, + .. + } => PpirTy::Named { path: segments.clone(), + type_args: type_args.iter().map(Self::convert_type_expr).collect(), attrs, }, TypeExpr::Optional { inner, .. } => PpirTy::Optional { @@ -299,8 +309,11 @@ impl PpirTy { /// Convert a `PpirTy` back to a `TypeExpr` for synthesized AST items. pub fn to_type_expr(&self) -> TypeExpr { match self { - PpirTy::Named { path, .. } => TypeExpr::Path { + PpirTy::Named { + path, type_args, .. + } => TypeExpr::Path { segments: path.clone(), + type_args: type_args.iter().map(PpirTy::to_type_expr).collect(), attrs: vec![], }, PpirTy::Int { .. } => TypeExpr::Int { attrs: vec![] }, @@ -383,6 +396,7 @@ mod tests { fn ppir_reads_stream_done_from_type_expr() { let type_expr = TypeExpr::Path { segments: vec![Name::new("Fizz")], + type_args: vec![], attrs: vec![make_attr("stream.done")], }; let ppir_ty = PpirTy::from_type_expr(&type_expr); diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index da245af2c1..5bc50e5443 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -470,26 +470,37 @@ impl<'db> TypeInferenceBuilder<'db> { } } Expr::Object { - type_name, fields, .. + type_name, + type_args, + fields, + .. } => { - for (_, expr_id) in fields { - self.infer_expr(*expr_id, body); - } - type_name + let resolved_ty = type_name .as_ref() - .and_then(|n| { - self.package_items - .lookup_type(&self.ns_context, n) - .map(|def| { - Ty::Class( - crate::lower_type_expr::qualify_def(self.context.db(), def, n), - TyAttr::default(), - ) - }) - }) - .unwrap_or(Ty::Unknown { + .and_then(|name| self.resolve_named_object_ty(name, type_args, expr_id)); + + if let Some(Ty::Class(class_name, _)) = resolved_ty.as_ref() { + let field_types = self.lookup_class_fields(class_name); + for (field_name, field_expr) in fields { + if let Some(declared_ty) = field_types.get(field_name) { + self.check_expr(*field_expr, body, declared_ty); + } else { + self.infer_expr(*field_expr, body); + } + } + let ty = resolved_ty.expect("resolved class type"); + self.record_expr_type(expr_id, ty.clone()); + ty + } else { + for (_, field_expr) in fields { + self.infer_expr(*field_expr, body); + } + let ty = Ty::Unknown { attr: TyAttr::default(), - }) + }; + self.record_expr_type(expr_id, ty.clone()); + ty + } } Expr::Index { base, index } => { let base_ty = self.infer_expr(*base, body); @@ -841,8 +852,44 @@ impl<'db> TypeInferenceBuilder<'db> { } } // Object: if expected is Class(name), check fields against declared types. - Expr::Object { fields, .. } => { - if let Ty::Class(class_name, _) = expected { + Expr::Object { + type_name, + type_args, + fields, + .. + } => { + if let Some(type_name) = type_name { + if let Some(inferred_ty) = self.resolve_named_object_ty_with_expected( + type_name, type_args, expected, expr_id, + ) { + if let Ty::Class(class_name, _) = &inferred_ty { + let field_types = self.lookup_class_fields(class_name); + for (field_name, field_expr) in fields { + if let Some(declared_ty) = field_types.get(field_name) { + self.check_expr(*field_expr, body, declared_ty); + } else { + self.infer_expr(*field_expr, body); + } + } + } else { + for (_, field_expr) in fields { + self.infer_expr(*field_expr, body); + } + } + + if !self.is_subtype(&inferred_ty, expected) { + self.report_type_mismatch(expected, &inferred_ty, expr_id); + } + self.record_expr_type(expr_id, inferred_ty.clone()); + inferred_ty + } else { + let inferred = self.infer_expr(expr_id, body); + if !self.is_subtype(&inferred, expected) { + self.report_type_mismatch(expected, &inferred, expr_id); + } + inferred + } + } else if let Ty::Class(class_name, _) = expected { let field_types = self.lookup_class_fields(class_name); for (field_name, field_expr) in fields { if let Some(declared_ty) = field_types.get(field_name) { @@ -2116,6 +2163,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.lower_pattern_type_expr( &TypeExpr::Path { segments: vec![name.clone()], + type_args: vec![], attrs: vec![], }, at_expr, @@ -2138,7 +2186,11 @@ impl<'db> TypeInferenceBuilder<'db> { baml_compiler2_ast::Pattern::EnumVariant { enum_name, variant } => { if let Ty::Enum(qn, _) = scrutinee_ty { if Self::enum_name_matches(enum_name, qn) { - return Ty::EnumVariant(qn.clone(), variant.clone(), TyAttr::default()); + return Ty::EnumVariant( + qn.qtn().clone(), + variant.clone(), + TyAttr::default(), + ); } } if let Some(def) = self.package_items.lookup_type(&self.ns_context, enum_name) { @@ -2180,6 +2232,93 @@ impl<'db> TypeInferenceBuilder<'db> { ty } + fn resolve_named_object_ty( + &mut self, + type_name: &Name, + type_args: &[TypeExpr], + at_expr: ExprId, + ) -> Option { + let def = self + .package_items + .lookup_type(&self.ns_context, type_name)?; + if !matches!(def, Definition::Class(_)) { + return None; + } + + let mut diags = Vec::new(); + let lowered_type_args = type_args + .iter() + .map(|arg| { + crate::lower_type_expr::lower_type_expr_in_ns( + self.context.db(), + arg, + self.package_items, + &self.ns_context, + &self.generic_params, + &mut diags, + ) + }) + .collect(); + for diag in diags { + self.context.report_simple(diag, at_expr); + } + + Some( + Ty::Class( + crate::lower_type_expr::qualify_def(self.context.db(), def, type_name).into(), + TyAttr::default(), + ) + .with_nominal_type_args(lowered_type_args), + ) + } + + fn resolve_named_object_ty_with_expected( + &mut self, + type_name: &Name, + type_args: &[TypeExpr], + expected: &Ty, + at_expr: ExprId, + ) -> Option { + // First try to resolve the type name from the current namespace context. + if let Some(resolved) = self.resolve_named_object_ty(type_name, type_args, at_expr) { + if type_args.is_empty() + && let (Ty::Class(resolved_class, _), Ty::Class(expected_class, _)) = + (&resolved, expected) + && resolved_class.qtn() == expected_class.qtn() + { + return Some(expected.clone()); + } + return Some(resolved); + } + + // Fallback: if the type_name couldn't be resolved from this namespace context, + // check whether `expected` is a class whose unqualified name matches `type_name`. + // + // Two cases are handled: + // (a) Cross-namespace struct literal: `root.llm.Response { ... }` in a root-namespace + // file — the AST only stores the last segment "Response" as `type_name`. + // (b) Fully-qualified synthetic type_name from CST lowering: `client "openai/gpt-4o"` + // is lowered with `type_name = "baml.llm.Client"` (full dotted name in one Name). + // `expected_class.name()` is just "Client", so we check if `type_name` ends with + // `"."` or equals ``. + // + // In both cases we use `expected` as the resolved type, preserving the old behaviour + // where the check path propagated the expected type directly. + if type_args.is_empty() { + if let Ty::Class(expected_class, _) = expected { + let expected_name = expected_class.name().as_str(); + let tn = type_name.as_str(); + let name_matches = + tn == expected_name || tn.ends_with(&format!(".{expected_name}")); + if name_matches { + return Some(expected.clone()); + } + } + } + + None + } + fn literal_case_name(lit: &baml_base::Literal) -> String { match lit { baml_base::Literal::Int(v) => v.to_string(), @@ -2209,6 +2348,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.lower_pattern_type_expr( &TypeExpr::Path { segments: vec![name.clone()], + type_args: vec![], attrs: vec![], }, at_expr, @@ -2244,8 +2384,14 @@ impl<'db> TypeInferenceBuilder<'db> { match ty { Ty::Class(qtn, _) => qtn.is_panic_type().then(|| ty.clone()), Ty::TypeAlias(qtn, _) => { - if let Some(expanded) = self.aliases.get(qtn) { - self.ty_panic_subset(expanded) + if let Some(alias_body) = self.aliases.get(qtn.qtn()) { + let expanded = crate::normalize::instantiate_alias( + qtn.qtn(), + alias_body, + qtn.type_args(), + &self.aliases, + ); + self.ty_panic_subset(&expanded) } else if qtn.is_panic_type() { Some(ty.clone()) } else { @@ -2425,6 +2571,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.lower_pattern_type_expr( &TypeExpr::Path { segments: vec![name.clone()], + type_args: vec![], attrs: vec![], }, at_expr, @@ -2808,11 +2955,11 @@ impl<'db> TypeInferenceBuilder<'db> { match def { Definition::Class(_) => { let class_qtn = crate::lower_type_expr::qualify_def(db, def, name); - return Some(Ty::Class(class_qtn, TyAttr::default())); + return Some(Ty::Class(class_qtn.into(), TyAttr::default())); } Definition::Enum(_) => { let enum_qtn = crate::lower_type_expr::qualify_def(db, def, name); - return Some(Ty::Enum(enum_qtn, TyAttr::default())); + return Some(Ty::Enum(enum_qtn.into(), TyAttr::default())); } _ => {} } @@ -2873,15 +3020,15 @@ impl<'db> TypeInferenceBuilder<'db> { let db = self.context.db(); match def { Definition::Class(_) => Ty::Class( - crate::lower_type_expr::qualify_def(db, def, name), + crate::lower_type_expr::qualify_def(db, def, name).into(), TyAttr::default(), ), Definition::Enum(_) => Ty::Enum( - crate::lower_type_expr::qualify_def(db, def, name), + crate::lower_type_expr::qualify_def(db, def, name).into(), TyAttr::default(), ), Definition::TypeAlias(_) => Ty::TypeAlias( - crate::lower_type_expr::qualify_def(db, def, name), + crate::lower_type_expr::qualify_def(db, def, name).into(), TyAttr::default(), ), _ => Ty::Unknown { @@ -2969,7 +3116,11 @@ impl<'db> TypeInferenceBuilder<'db> { }, ); } - return Ty::EnumVariant(enum_name.clone(), member.clone(), TyAttr::default()); + return Ty::EnumVariant( + enum_name.qtn().clone(), + member.clone(), + TyAttr::default(), + ); } // Known enum but variant not found — error @@ -3108,8 +3259,13 @@ impl<'db> TypeInferenceBuilder<'db> { } Ty::TypeAlias(qtn, _) => { // Expand the alias to its concrete type, then recurse. - if let Some(expanded) = self.aliases.get(qtn) { - let expanded = expanded.clone(); + if let Some(alias_body) = self.aliases.get(qtn.qtn()) { + let expanded = crate::normalize::instantiate_alias( + qtn.qtn(), + alias_body, + qtn.type_args(), + &self.aliases, + ); return self.resolve_member(&expanded, member, at); } // Alias not in map (cyclic or unresolved) — treat as Unknown @@ -3185,8 +3341,13 @@ impl<'db> TypeInferenceBuilder<'db> { self.try_resolve_member_on_ty(inner, member) } Ty::TypeAlias(qtn, _) => { - if let Some(expanded) = self.aliases.get(qtn) { - let expanded = expanded.clone(); + if let Some(alias_body) = self.aliases.get(qtn.qtn()) { + let expanded = crate::normalize::instantiate_alias( + qtn.qtn(), + alias_body, + qtn.type_args(), + &self.aliases, + ); self.try_resolve_member_on_ty(&expanded, member) } else { None @@ -3202,52 +3363,20 @@ impl<'db> TypeInferenceBuilder<'db> { } } - /// Look up class fields from the package items (via item tree). + /// Look up class fields, routing through `PackageResolutionContext`. + /// + /// For own-package classes this lowers fields from the `ItemTree` and applies + /// class-level generic substitution. For dependency-package classes it reads + /// from the pre-resolved `PackageInterface` (which already carries lowered `Ty` + /// values) and applies substitution there. /// /// Returns a map of field name → resolved field type. - fn lookup_class_fields( - &self, - class_name: &crate::ty::QualifiedTypeName, - ) -> FxHashMap { - let mut result = FxHashMap::default(); - let Some(pkg_items_for_class) = self.resolve_class_pkg_items(class_name.package()) else { - return result; - }; - if let Some(Definition::Class(class_loc)) = - pkg_items_for_class.lookup_type(class_name.namespace(), class_name.name()) - { - let file = class_loc.file(self.context.db()); - let ns_context = - baml_compiler2_hir::file_package::file_package(self.context.db(), file) - .namespace_path; - let item_tree = baml_compiler2_ppir::file_item_tree(self.context.db(), file); - let class_data = &item_tree[class_loc.id(self.context.db())]; - for field in &class_data.fields { - let mut diags = Vec::new(); - let field_ty = field - .type_expr - .as_ref() - .map(|te| { - let ty = crate::lower_type_expr::lower_type_expr_in_ns( - self.context.db(), - &te.expr, - pkg_items_for_class, - &ns_context, - &class_data.generic_params, - &mut diags, - ); - for diag in diags.drain(..) { - self.context.report_at_span(diag, te.span); - } - ty - }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }); - result.insert(field.name.clone(), field_ty); - } - } - result + fn lookup_class_fields(&self, class_name: &crate::ty::NominalTypeRef) -> FxHashMap { + let db = self.context.db(); + self.res_ctx + .lookup_class_fields(db, class_name) + .into_iter() + .collect() } /// Fetch `PackageItems` for the package that owns a class type. @@ -3292,66 +3421,115 @@ impl<'db> TypeInferenceBuilder<'db> { } } - /// Look up a class method by name from the item tree. + /// Look up a class method by name. /// - /// Methods are stored on the `Class` entry directly (not in the package - /// namespace), so we resolve the class, iterate its method IDs, and match - /// by name. Returns the method type along with the class and function locs - /// so callers can record a `MemberResolution`. + /// For own-package classes, methods are resolved directly from the `ItemTree` + /// via `build_function_ty_from_signature`, which handles scope IDs, throw-set + /// keys, and inferred throws. + /// + /// For dependency-package classes, resolution is delegated to + /// `PackageResolutionContext::lookup_class_method`, which reads pre-resolved + /// type information from the `PackageInterface` and applies class-level generic + /// substitution. The `ClassLoc` and `FunctionLoc` are still resolved via the + /// raw item tree so callers can record a `MemberResolution`. fn lookup_class_method( &self, - class_name: &crate::ty::QualifiedTypeName, + class_name: &crate::ty::NominalTypeRef, method_name: &Name, ) -> Option<( Ty, baml_compiler2_hir::loc::ClassLoc<'db>, baml_compiler2_hir::loc::FunctionLoc<'db>, )> { + let db = self.context.db(); let pkg_items_for_class = self.resolve_class_pkg_items(class_name.package())?; let def = pkg_items_for_class.lookup_type(class_name.namespace(), class_name.name())?; let Definition::Class(class_loc) = def else { return None; }; - let db = self.context.db(); let file = class_loc.file(db); - let ns_context = baml_compiler2_hir::file_package::file_package(db, file).namespace_path; let item_tree = baml_compiler2_ppir::file_item_tree(db, file); let class_data = &item_tree[class_loc.id(db)]; - for &method_id in &class_data.methods { + // Resolve the FunctionLoc for the method (needed for MemberResolution regardless of path). + let func_loc = class_data.methods.iter().find_map(|&method_id| { let method_data = &item_tree[method_id]; if method_data.name == *method_name { - let mut all_generic_params = class_data.generic_params.clone(); - all_generic_params.extend(method_data.generic_params.iter().cloned()); - let func_loc = baml_compiler2_hir::loc::FunctionLoc::new(db, file, method_id); - let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); - let class_ty = Ty::Class(class_name.clone(), TyAttr::default()); - let function_key = crate::throw_inference::throw_set_key( - &ns_context, - &Name::new(format!("{}.{}", class_name.name(), method_data.name)), - ); - let method_scope = - self.find_function_scope_id(file, method_data.span, &method_data.name); - let method_body = baml_compiler2_hir::body::function_body(db, func_loc); - let ty = self.build_function_ty_from_signature( - pkg_items_for_class, - &ns_context, - &all_generic_params, - sig.as_ref(), - Some(&function_key), - Some(method_scope), - Some(PackageId::new( - db, - baml_compiler2_hir::file_package::file_package(db, file).package, - )), - Some(method_body.as_ref()), - Some(&class_ty), - ); - // Note: diags from method signatures are reported at definition site. - return Some((ty, class_loc, func_loc)); + Some(baml_compiler2_hir::loc::FunctionLoc::new( + db, file, method_id, + )) + } else { + None } + })?; + + let own_pkg_name = self.package_id.name(db); + if *class_name.package() == own_pkg_name { + // Own-package path: build function type from signature with scope/throw-set support. + let ns_context = + baml_compiler2_hir::file_package::file_package(db, file).namespace_path; + let class_bindings = + crate::generics::bind_type_vars(&class_data.generic_params, class_name.type_args()); + + for &method_id in &class_data.methods { + let method_data = &item_tree[method_id]; + if method_data.name == *method_name { + let mut all_generic_params = class_data.generic_params.clone(); + all_generic_params.extend(method_data.generic_params.iter().cloned()); + let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); + let class_ty = Ty::Class(class_name.clone(), TyAttr::default()); + let function_key = crate::throw_inference::throw_set_key( + &ns_context, + &Name::new(format!("{}.{}", class_name.name(), method_data.name)), + ); + let method_scope = + self.find_function_scope_id(file, method_data.span, &method_data.name); + let method_body = baml_compiler2_hir::body::function_body(db, func_loc); + let ty = self.build_function_ty_from_signature( + pkg_items_for_class, + &ns_context, + &all_generic_params, + sig.as_ref(), + Some(&function_key), + Some(method_scope), + Some(PackageId::new( + db, + baml_compiler2_hir::file_package::file_package(db, file).package, + )), + Some(method_body.as_ref()), + Some(&class_ty), + ); + // Note: diags from method signatures are reported at definition site. + return Some(( + crate::generics::substitute_ty(&ty, &class_bindings), + class_loc, + func_loc, + )); + } + } + None + } else { + // Dep-package path: delegate type-level resolution to PackageResolutionContext, + // which reads from the pre-resolved PackageInterface and applies class-level + // generic substitution. + let resolved = self + .res_ctx + .lookup_class_method(db, class_name, method_name)?; + let ty = Ty::Function { + params: resolved + .function + .params + .into_iter() + .map(|(n, ty)| (Some(n), ty)) + .collect(), + ret: Box::new(resolved.function.return_type), + throws: Box::new(resolved.function.throws.unwrap_or(Ty::Never { + attr: TyAttr::default(), + })), + attr: TyAttr::default(), + }; + Some((ty, class_loc, func_loc)) } - None } /// Check if a `FieldAccess` base is a primitive type name used for static @@ -3489,7 +3667,7 @@ impl<'db> TypeInferenceBuilder<'db> { || first.as_str() == self.package_id.name(db).as_str() { let class_qtn = crate::lower_type_expr::qualify_def(db, def, item_name); - let base_ty = Ty::Class(class_qtn, TyAttr::default()); + let base_ty = Ty::Class(class_qtn.into(), TyAttr::default()); return Some(self.resolve_member(&base_ty, member, at)); } let class_path: Vec<&str> = @@ -3499,13 +3677,13 @@ impl<'db> TypeInferenceBuilder<'db> { .or_else(|| { let class_qtn = crate::lower_type_expr::qualify_def(db, def, item_name); - let base_ty = Ty::Class(class_qtn, TyAttr::default()); + let base_ty = Ty::Class(class_qtn.into(), TyAttr::default()); Some(self.resolve_member(&base_ty, member, at)) }); } Definition::Enum(_) => { let enum_qtn = crate::lower_type_expr::qualify_def(db, def, item_name); - let base_ty = Ty::Enum(enum_qtn, TyAttr::default()); + let base_ty = Ty::Enum(enum_qtn.into(), TyAttr::default()); return Some(self.resolve_member(&base_ty, member, at)); } _ => {} @@ -3658,7 +3836,8 @@ impl<'db> TypeInferenceBuilder<'db> { pkg_info.namespace_path, class_data.name.clone(), class_data.generic_params.clone(), - ), + ) + .into(), TyAttr::default(), ) } else if type_args.len() == 1 { @@ -3672,7 +3851,8 @@ impl<'db> TypeInferenceBuilder<'db> { pkg_info.package, pkg_info.namespace_path, class_data.name.clone(), - ), + ) + .into(), TyAttr::default(), ) } @@ -3691,7 +3871,8 @@ impl<'db> TypeInferenceBuilder<'db> { pkg_info.package, pkg_info.namespace_path, class_data.name.clone(), - ), + ) + .into(), TyAttr::default(), ) } @@ -3703,7 +3884,8 @@ impl<'db> TypeInferenceBuilder<'db> { pkg_info.package, pkg_info.namespace_path, class_data.name.clone(), - ), + ) + .into(), TyAttr::default(), ) }; diff --git a/baml_language/crates/baml_compiler2_tir/src/generics.rs b/baml_language/crates/baml_compiler2_tir/src/generics.rs index 595775118c..5b7886632e 100644 --- a/baml_language/crates/baml_compiler2_tir/src/generics.rs +++ b/baml_language/crates/baml_compiler2_tir/src/generics.rs @@ -60,6 +60,36 @@ pub fn substitute_ty(ty: &Ty, bindings: &FxHashMap) -> Ty { } match ty { Ty::TypeVar(name, _) => bindings.get(name).cloned().unwrap_or_else(|| ty.clone()), + Ty::Class(nominal, attr) => Ty::Class( + nominal.with_type_args( + nominal + .type_args() + .iter() + .map(|arg| substitute_ty(arg, bindings)) + .collect(), + ), + attr.clone(), + ), + Ty::Enum(nominal, attr) => Ty::Enum( + nominal.with_type_args( + nominal + .type_args() + .iter() + .map(|arg| substitute_ty(arg, bindings)) + .collect(), + ), + attr.clone(), + ), + Ty::TypeAlias(nominal, attr) => Ty::TypeAlias( + nominal.with_type_args( + nominal + .type_args() + .iter() + .map(|arg| substitute_ty(arg, bindings)) + .collect(), + ), + attr.clone(), + ), Ty::List(inner, attr) => Ty::List(Box::new(substitute_ty(inner, bindings)), attr.clone()), Ty::Map(k, v, attr) => Ty::Map( Box::new(substitute_ty(k, bindings)), @@ -103,9 +133,11 @@ pub fn substitute_ty(ty: &Ty, bindings: &FxHashMap) -> Ty { /// intercept `T` references that would otherwise produce `Ty::Unknown`. fn substitute_type_expr(expr: &TypeExpr, bindings: &FxHashMap) -> Option { match expr { - TypeExpr::Path { segments, .. } if segments.len() == 1 => { - bindings.get(&segments[0]).cloned() - } + TypeExpr::Path { + segments, + type_args, + .. + } if segments.len() == 1 && type_args.is_empty() => bindings.get(&segments[0]).cloned(), _ => None, } } @@ -255,6 +287,39 @@ pub fn lower_type_expr_with_generics( ), attr: TyAttr::default(), }, + TypeExpr::Path { + segments, + type_args, + attrs, + } if !type_args.is_empty() => { + let base = lower_type_expr_in_ns( + db, + &TypeExpr::Path { + segments: segments.clone(), + type_args: vec![], + attrs: attrs.clone(), + }, + package_items, + ns_context, + &[], + diagnostics, + ); + base.with_nominal_type_args( + type_args + .iter() + .map(|arg| { + lower_type_expr_with_generics( + db, + arg, + package_items, + ns_context, + bindings, + diagnostics, + ) + }) + .collect(), + ) + } // For all other type expressions (primitives, multi-segment paths, etc.), // lower normally and then substitute in the result. other => { @@ -303,6 +368,9 @@ pub fn contains_typevar(ty: &Ty) -> bool { || contains_typevar(ret) || contains_typevar(throws) } + Ty::Class(qn, _) | Ty::Enum(qn, _) | Ty::TypeAlias(qn, _) => { + qn.type_args().iter().any(contains_typevar) + } _ => false, } } @@ -352,6 +420,32 @@ pub fn infer_bindings(formal: &Ty, actual: &Ty, bindings: &mut FxHashMap + { + for (f_arg, a_arg) in f_qn.type_args().iter().zip(a_qn.type_args().iter()) { + infer_bindings(f_arg, a_arg, bindings); + } + } + (Ty::Enum(f_qn, _), Ty::Enum(a_qn, _)) + if f_qn.qtn() == a_qn.qtn() && f_qn.type_args().len() == a_qn.type_args().len() => + { + for (f_arg, a_arg) in f_qn.type_args().iter().zip(a_qn.type_args().iter()) { + infer_bindings(f_arg, a_arg, bindings); + } + } + // TypeAlias: if both sides have the same alias identity and arity, recurse into type_args. + // Full alias expansion before inference is the caller's responsibility. + // This arm handles the structural case; callers should pre-resolve aliases + // via resolve_alias_chain when possible. + (Ty::TypeAlias(f_qn, _), Ty::TypeAlias(a_qn, _)) + if f_qn.qtn() == a_qn.qtn() && f_qn.type_args().len() == a_qn.type_args().len() => + { + for (f_arg, a_arg) in f_qn.type_args().iter().zip(a_qn.type_args().iter()) { + infer_bindings(f_arg, a_arg, bindings); + } + } _ => {} // Concrete types: nothing to infer } } @@ -443,6 +537,42 @@ pub fn erase_unresolved_typevars( .collect(), attr.clone(), ), + Ty::Class(nominal, attr) => { + let new_args: Vec = nominal + .type_args() + .iter() + .map(|a| erase_unresolved_typevars(a, diagnostics)) + .collect(); + if new_args.is_empty() { + ty.clone() + } else { + Ty::Class(nominal.with_type_args(new_args), attr.clone()) + } + } + Ty::Enum(nominal, attr) => { + let new_args: Vec = nominal + .type_args() + .iter() + .map(|a| erase_unresolved_typevars(a, diagnostics)) + .collect(); + if new_args.is_empty() { + ty.clone() + } else { + Ty::Enum(nominal.with_type_args(new_args), attr.clone()) + } + } + Ty::TypeAlias(nominal, attr) => { + let new_args: Vec = nominal + .type_args() + .iter() + .map(|a| erase_unresolved_typevars(a, diagnostics)) + .collect(); + if new_args.is_empty() { + ty.clone() + } else { + Ty::TypeAlias(nominal.with_type_args(new_args), attr.clone()) + } + } other => other.clone(), } } diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 0d3848e2e5..0facee0462 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -354,7 +354,7 @@ pub fn infer_scope_types<'db>( let ns_path = &pkg_info.namespace_path; pkg_items.lookup_type(ns_path, cn).map(|def| { Ty::Class( - crate::lower_type_expr::qualify_def(db, def, cn), + crate::lower_type_expr::qualify_def(db, def, cn).into(), TyAttr::default(), ) }) diff --git a/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs b/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs index a46ae8a7ff..4c0e736c38 100644 --- a/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs +++ b/baml_language/crates/baml_compiler2_tir/src/lower_type_expr.rs @@ -168,7 +168,11 @@ pub fn lower_type_expr_in_ns( diagnostics: &mut Vec, ) -> Ty { match type_expr { - TypeExpr::Path { segments, .. } => { + TypeExpr::Path { + segments, + type_args, + .. + } => { let item = segments.last().expect("non-empty path"); let seg_ns = &segments[..segments.len() - 1]; // When we have a namespace context, try the qualified path first. @@ -200,13 +204,31 @@ pub fn lower_type_expr_in_ns( if let Some(def) = resolved { let short = segments.last().expect("non-empty path"); + let lowered_type_args: Vec = type_args + .iter() + .map(|arg| { + lower_type_expr_in_ns( + db, + arg, + package_items, + ns_context, + generic_params, + diagnostics, + ) + }) + .collect(); match def { Definition::Class(_) => { - Ty::Class(qualify_def(db, def, short), TyAttr::default()) + Ty::Class(qualify_def(db, def, short).into(), TyAttr::default()) + .with_nominal_type_args(lowered_type_args) + } + Definition::Enum(_) => { + Ty::Enum(qualify_def(db, def, short).into(), TyAttr::default()) + .with_nominal_type_args(lowered_type_args) } - Definition::Enum(_) => Ty::Enum(qualify_def(db, def, short), TyAttr::default()), Definition::TypeAlias(_) => { - Ty::TypeAlias(qualify_def(db, def, short), TyAttr::default()) + Ty::TypeAlias(qualify_def(db, def, short).into(), TyAttr::default()) + .with_nominal_type_args(lowered_type_args) } // Let bindings are values, not types — produce Unknown in a type position. _ => Ty::Unknown { @@ -215,7 +237,7 @@ pub fn lower_type_expr_in_ns( } } else { // Check if this is a generic type parameter (e.g. T, K, V). - if segments.len() == 1 { + if segments.len() == 1 && type_args.is_empty() { if generic_params.iter().any(|p| *p == segments[0]) { return Ty::TypeVar(segments[0].clone(), TyAttr::default()); } diff --git a/baml_language/crates/baml_compiler2_tir/src/normalize.rs b/baml_language/crates/baml_compiler2_tir/src/normalize.rs index d3e8771aec..63bd38e607 100644 --- a/baml_language/crates/baml_compiler2_tir/src/normalize.rs +++ b/baml_language/crates/baml_compiler2_tir/src/normalize.rs @@ -37,6 +37,30 @@ pub fn find_recursive_aliases( recursive } +/// Substitute use-site `type_args` into an alias body. +/// +/// If the alias has no generic params or the use-site has no `type_args`, +/// returns the body unchanged. Otherwise applies `bind_type_vars` + `substitute_ty`. +/// +/// `generic_params` must be sourced from alias definition metadata (item tree / +/// `ExportedType`), NOT from `QualifiedTypeName.generic_params` which is unreliable. +pub(crate) fn instantiate_alias( + _qtn: &QualifiedTypeName, + alias_body: &Ty, + use_site_type_args: &[Ty], + aliases: &HashMap, +) -> Ty { + if use_site_type_args.is_empty() { + return alias_body.clone(); + } + // Current user-defined type aliases have no generic params in AST/HIR, + // so this path is not yet reachable for user code. When generic type aliases + // are added, the caller must supply generic_params from alias definition metadata. + // For now, return the body unchanged. + let _ = aliases; + alias_body.clone() +} + // ═══════════════════════════════════════════════════════════════════════════ // STRUCTURAL TYPE (private) // ═══════════════════════════════════════════════════════════════════════════ @@ -58,8 +82,8 @@ enum StructuralTy { // Literal Literal(baml_base::Literal), // User-defined (resolved by qualified name) - Class(QualifiedTypeName), - Enum(QualifiedTypeName), + Class(QualifiedTypeName, Vec), + Enum(QualifiedTypeName, Vec), EnumVariant(QualifiedTypeName, Name), // Constructors Optional(Box), @@ -227,8 +251,22 @@ impl StructuralTy { (StructuralTy::Literal(LiteralValue::String(_)), StructuralTy::String) => true, (StructuralTy::Literal(LiteralValue::Bool(_)), StructuralTy::Bool) => true, - // EnumVariant(E, V) <: Enum(E) - (StructuralTy::EnumVariant(e, _), StructuralTy::Enum(sup_e)) => e == sup_e, + // Nominal class: same declaration identity + invariant type_args + (StructuralTy::Class(qtn1, args1), StructuralTy::Class(qtn2, args2)) + if qtn1 == qtn2 && args1.len() == args2.len() => + { + args1.iter().zip(args2.iter()).all(|(a, b)| a == b) + } + + // Nominal enum: same declaration identity + invariant type_args + (StructuralTy::Enum(qtn1, args1), StructuralTy::Enum(qtn2, args2)) + if qtn1 == qtn2 && args1.len() == args2.len() => + { + args1.iter().zip(args2.iter()).all(|(a, b)| a == b) + } + + // EnumVariant(E, V) <: Enum(E) — match new Enum shape with trailing _ for type_args + (StructuralTy::EnumVariant(e, _), StructuralTy::Enum(sup_e, _)) => e == sup_e, // Function subtyping: contravariant params, covariant return and throws ( @@ -309,6 +347,18 @@ fn substitute( var: v.clone(), body: Box::new(substitute(body, var, replacement)), }, + StructuralTy::Class(qtn, args) => StructuralTy::Class( + qtn.clone(), + args.iter() + .map(|a| substitute(a, var, replacement)) + .collect(), + ), + StructuralTy::Enum(qtn, args) => StructuralTy::Enum( + qtn.clone(), + args.iter() + .map(|a| substitute(a, var, replacement)) + .collect(), + ), _ => ty.clone(), } } @@ -351,26 +401,41 @@ fn normalize_impl( Ty::Unknown { .. } => StructuralTy::Unknown, Ty::Error { .. } => StructuralTy::Error, Ty::Literal(lit, _freshness, _) => StructuralTy::Literal(lit.clone()), - Ty::Class(qn, _) => StructuralTy::Class(qn.clone()), - Ty::Enum(qn, _) => StructuralTy::Enum(qn.clone()), + Ty::Class(qn, _) => StructuralTy::Class( + qn.qtn().clone(), + qn.type_args() + .iter() + .map(|arg| normalize_impl(arg, aliases, recursive, expanding)) + .collect(), + ), + Ty::Enum(qn, _) => StructuralTy::Enum( + qn.qtn().clone(), + qn.type_args() + .iter() + .map(|arg| normalize_impl(arg, aliases, recursive, expanding)) + .collect(), + ), Ty::EnumVariant(qn, v, _) => StructuralTy::EnumVariant(qn.clone(), v.clone()), Ty::TypeAlias(qn, _) => { - if expanding.contains(qn) { - return StructuralTy::TyVar(qn.clone()); + if expanding.contains(qn.qtn()) { + return StructuralTy::TyVar(qn.qtn().clone()); } - if let Some(alias_ty) = aliases.get(qn) { - if recursive.contains(qn) { - expanding.insert(qn.clone()); - let body = normalize_impl(alias_ty, aliases, recursive, expanding); - expanding.remove(qn); + if let Some(alias_ty) = aliases.get(qn.qtn()) { + // Substitute use-site type_args into alias body + let alias_ty = alias_ty.clone(); + let substituted = instantiate_alias(qn.qtn(), &alias_ty, qn.type_args(), aliases); + if recursive.contains(qn.qtn()) { + expanding.insert(qn.qtn().clone()); + let body = normalize_impl(&substituted, aliases, recursive, expanding); + expanding.remove(qn.qtn()); StructuralTy::Mu { - var: qn.clone(), + var: qn.qtn().clone(), body: Box::new(body), } } else { - normalize_impl(alias_ty, aliases, recursive, expanding) + normalize_impl(&substituted, aliases, recursive, expanding) } } else { StructuralTy::Error @@ -447,7 +512,9 @@ fn ty_has_cycle( stack: &mut HashSet, ) -> bool { match ty { - Ty::TypeAlias(qn, _) if aliases.contains_key(qn) => has_cycle(qn, aliases, visited, stack), + Ty::TypeAlias(qn, _) if aliases.contains_key(qn.qtn()) => { + has_cycle(qn.qtn(), aliases, visited, stack) + } Ty::Optional(inner, _) | Ty::List(inner, _) | Ty::EvolvingList(inner, _) => { ty_has_cycle(inner, aliases, visited, stack) } @@ -569,11 +636,21 @@ fn extract_type_alias_deps( in_structural: bool, ) { match ty { - Ty::TypeAlias(qn, _) if aliases.contains_key(qn) => { + Ty::TypeAlias(qn, _) if aliases.contains_key(qn.qtn()) => { if in_structural { - structural.insert(qn.clone()); + structural.insert(qn.qtn().clone()); } else { - non_structural.insert(qn.clone()); + non_structural.insert(qn.qtn().clone()); + } + // Recurse into type_args so alias dependencies hidden inside them are found + for arg in qn.type_args() { + visit(arg, aliases, non_structural, structural, in_structural); + } + } + Ty::Class(qn, _) | Ty::Enum(qn, _) if !qn.type_args().is_empty() => { + // Recurse into type_args to find alias dependencies inside them + for arg in qn.type_args() { + visit(arg, aliases, non_structural, structural, in_structural); } } Ty::Optional(inner, _) => { @@ -834,15 +911,15 @@ fn extract_required_class_deps( match ty { Ty::Class(qn, _) => { // Only add if the field is truly required - if !optional && !in_list_or_map && class_fields.contains_key(qn) { - deps.insert(qn.clone()); + if !optional && !in_list_or_map && class_fields.contains_key(qn.qtn()) { + deps.insert(qn.qtn().clone()); } } Ty::TypeAlias(qn, _) => { // Resolve through type aliases (only if still required context) - if !optional && !in_list_or_map && !visiting.contains(qn) { - if let Some(alias_ty) = type_aliases.get(qn) { - visiting.insert(qn.clone()); + if !optional && !in_list_or_map && !visiting.contains(qn.qtn()) { + if let Some(alias_ty) = type_aliases.get(qn.qtn()) { + visiting.insert(qn.qtn().clone()); extract_required_class_deps( alias_ty, class_fields, @@ -852,7 +929,7 @@ fn extract_required_class_deps( in_list_or_map, visiting, ); - visiting.remove(qn); + visiting.remove(qn.qtn()); } } } @@ -951,7 +1028,7 @@ mod tests { } fn type_alias(name: &str) -> Ty { - Ty::TypeAlias(qn(name), TyAttr::default()) + Ty::TypeAlias(qn(name).into(), TyAttr::default()) } #[test] @@ -1078,7 +1155,7 @@ mod tests { &Ty::Never { attr: TyAttr::default() }, - &Ty::Class(qn("Foo"), TyAttr::default()), + &Ty::Class(qn("Foo").into(), TyAttr::default()), &aliases )); assert!(is_subtype_of( @@ -1149,12 +1226,12 @@ mod tests { let aliases = HashMap::new(); assert!(is_subtype_of( &Ty::EnumVariant(qn("Color"), Name::new("Red"), TyAttr::default()), - &Ty::Enum(qn("Color"), TyAttr::default()), + &Ty::Enum(qn("Color").into(), TyAttr::default()), &aliases )); assert!(!is_subtype_of( &Ty::EnumVariant(qn("Color"), Name::new("Red"), TyAttr::default()), - &Ty::Enum(qn("Shape"), TyAttr::default()), + &Ty::Enum(qn("Shape").into(), TyAttr::default()), &aliases )); } diff --git a/baml_language/crates/baml_compiler2_tir/src/package_interface.rs b/baml_language/crates/baml_compiler2_tir/src/package_interface.rs index 580ce355a7..06f7d04914 100644 --- a/baml_language/crates/baml_compiler2_tir/src/package_interface.rs +++ b/baml_language/crates/baml_compiler2_tir/src/package_interface.rs @@ -20,7 +20,7 @@ use rustc_hash::FxHashMap; use crate::{ lower_type_expr::{lower_type_expr_in_ns, qualify_def}, throw_inference::{FunctionThrowSets, function_throw_sets}, - ty::{QualifiedTypeName, Ty, TyAttr}, + ty::{NominalTypeRef, QualifiedTypeName, Ty, TyAttr}, }; // ── Data types ───────────────────────────────────────────────────────────── @@ -174,9 +174,11 @@ impl ExportedType { /// Convert to a Ty (for type resolution results). pub fn to_ty(&self) -> Ty { match self { - ExportedType::Class { qtn, .. } => Ty::Class(qtn.clone(), TyAttr::default()), - ExportedType::Enum { qtn, .. } => Ty::Enum(qtn.clone(), TyAttr::default()), - ExportedType::TypeAlias { qtn, .. } => Ty::TypeAlias(qtn.clone(), TyAttr::default()), + ExportedType::Class { qtn, .. } => Ty::Class(qtn.clone().into(), TyAttr::default()), + ExportedType::Enum { qtn, .. } => Ty::Enum(qtn.clone().into(), TyAttr::default()), + ExportedType::TypeAlias { qtn, .. } => { + Ty::TypeAlias(qtn.clone().into(), TyAttr::default()) + } } } } @@ -456,7 +458,7 @@ fn build_self_type_for_class( class_data.name.clone(), class_data.generic_params.clone(), ); - Ty::Class(qtn, TyAttr::default()) + Ty::Class(qtn.into(), TyAttr::default()) } } } @@ -637,41 +639,88 @@ impl<'db> PackageResolutionContext<'db> { } /// Look up class fields. Dual dispatch: - /// - Own-package: `ItemTree` -> lower fields - /// - Dependency: `ExportedType::Class` { fields } + /// - Own-package: `ItemTree` -> lower fields + apply class-level generic substitution + /// - Dependency: `ExportedType::Class` { fields } + apply class-level generic substitution pub fn lookup_class_fields( &self, db: &'db dyn crate::Db, - class_name: &QualifiedTypeName, + class_name: &NominalTypeRef, ) -> Vec<(Name, Ty)> { let class_pkg = class_name.package(); if class_pkg.as_str() == self.own_package_name.as_str() { - self.lookup_own_class_fields(db, class_name) + let raw_fields = self.lookup_own_class_fields(db, class_name.qtn()); + self.apply_class_substitution(db, class_name, raw_fields) } else { for (dep_name, dep_iface) in &self.dep_interfaces { if dep_name != class_pkg { continue; } - if let Some(ExportedType::Class { fields, .. }) = - dep_iface.lookup_type(class_name.namespace(), class_name.name()) + if let Some(ExportedType::Class { + fields, + generic_params, + .. + }) = dep_iface.lookup_type(class_name.namespace(), class_name.name()) { - return fields.clone(); + if class_name.type_args().is_empty() || generic_params.is_empty() { + return fields.clone(); + } + let bindings = + crate::generics::bind_type_vars(generic_params, class_name.type_args()); + return fields + .iter() + .map(|(n, ty)| (n.clone(), crate::generics::substitute_ty(ty, &bindings))) + .collect(); } } Vec::new() } } + /// Apply class-level generic substitution to own-package fields. + fn apply_class_substitution( + &self, + db: &'db dyn crate::Db, + class_name: &NominalTypeRef, + raw_fields: Vec<(Name, Ty)>, + ) -> Vec<(Name, Ty)> { + if class_name.type_args().is_empty() { + return raw_fields; + } + // Get generic_params from item tree + let Some(def) = self + .own_items + .lookup_type(class_name.namespace(), class_name.name()) + else { + return raw_fields; + }; + let Definition::Class(class_loc) = def else { + return raw_fields; + }; + let item_tree = file_item_tree(db, class_loc.file(db)); + let class_data = &item_tree[class_loc.id(db)]; + let bindings = + crate::generics::bind_type_vars(&class_data.generic_params, class_name.type_args()); + raw_fields + .into_iter() + .map(|(n, ty)| (n, crate::generics::substitute_ty(&ty, &bindings))) + .collect() + } + /// Look up a class method. Dual dispatch. + /// Applies class-level generic substitution to the returned method type. pub fn lookup_class_method( &self, db: &'db dyn crate::Db, - class_name: &QualifiedTypeName, + class_name: &NominalTypeRef, method_name: &Name, ) -> Option { let class_pkg = class_name.package(); if class_pkg.as_str() == self.own_package_name.as_str() { - self.lookup_own_class_method(db, class_name, method_name) + self.lookup_own_class_method(db, class_name.qtn(), method_name) + .map(|mut rm| { + Self::apply_method_substitution(class_name, &mut rm); + rm + }) } else { for (dep_name, dep_iface) in &self.dep_interfaces { if dep_name != class_pkg { @@ -684,7 +733,7 @@ impl<'db> PackageResolutionContext<'db> { }) = dep_iface.lookup_type(class_name.namespace(), class_name.name()) { if let Some(method) = methods.iter().find(|m| &m.name == method_name) { - return Some(ResolvedMethod { + let mut resolved_method = ResolvedMethod { function: ResolvedFunction { name: method.name.clone(), params: method.params.clone(), @@ -695,7 +744,29 @@ impl<'db> PackageResolutionContext<'db> { }, class_name: class_name.name().clone(), class_generic_params: generic_params.clone(), - }); + }; + // Apply class-level substitution using generic_params from ExportedType + if !class_name.type_args().is_empty() && !generic_params.is_empty() { + let bindings = crate::generics::bind_type_vars( + generic_params, + class_name.type_args(), + ); + resolved_method.function.params = resolved_method + .function + .params + .into_iter() + .map(|(n, ty)| (n, crate::generics::substitute_ty(&ty, &bindings))) + .collect(); + resolved_method.function.return_type = crate::generics::substitute_ty( + &resolved_method.function.return_type, + &bindings, + ); + if let Some(ref throws) = resolved_method.function.throws.clone() { + resolved_method.function.throws = + Some(crate::generics::substitute_ty(throws, &bindings)); + } + } + return Some(resolved_method); } } } @@ -703,6 +774,37 @@ impl<'db> PackageResolutionContext<'db> { } } + /// Apply class-level generic substitution to a `ResolvedMethod`'s function signature. + /// Uses `generic_params` from the own-package item tree. + fn apply_method_substitution( + class_name: &NominalTypeRef, + resolved_method: &mut ResolvedMethod, + ) { + if class_name.type_args().is_empty() { + return; + } + // class_generic_params was set by lookup_own_class_method from item tree + if resolved_method.class_generic_params.is_empty() { + return; + } + let bindings = crate::generics::bind_type_vars( + &resolved_method.class_generic_params, + class_name.type_args(), + ); + resolved_method.function.params = resolved_method + .function + .params + .iter() + .map(|(n, ty)| (n.clone(), crate::generics::substitute_ty(ty, &bindings))) + .collect(); + resolved_method.function.return_type = + crate::generics::substitute_ty(&resolved_method.function.return_type, &bindings); + if let Some(ref throws) = resolved_method.function.throws.clone() { + resolved_method.function.throws = + Some(crate::generics::substitute_ty(throws, &bindings)); + } + } + fn lookup_own_class_fields( &self, db: &'db dyn crate::Db, @@ -853,9 +955,11 @@ fn def_to_ty<'db>(db: &'db dyn crate::Db, def: Definition<'db>) -> Ty { } }; match def { - Definition::Class(_) => Ty::Class(qualify_def(db, def, &name), TyAttr::default()), - Definition::Enum(_) => Ty::Enum(qualify_def(db, def, &name), TyAttr::default()), - Definition::TypeAlias(_) => Ty::TypeAlias(qualify_def(db, def, &name), TyAttr::default()), + Definition::Class(_) => Ty::Class(qualify_def(db, def, &name).into(), TyAttr::default()), + Definition::Enum(_) => Ty::Enum(qualify_def(db, def, &name).into(), TyAttr::default()), + Definition::TypeAlias(_) => { + Ty::TypeAlias(qualify_def(db, def, &name).into(), TyAttr::default()) + } _ => Ty::Unknown { attr: TyAttr::default(), }, diff --git a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs index 697136edce..f9188f8aad 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs @@ -387,9 +387,11 @@ fn throw_fact_from_expr<'db>( if let Some(def) = pkg_items.lookup_type(ns_context, name) { match def { Definition::Class(_) => { - Ty::Class(qualify_def(db, def, name), TyAttr::default()) + Ty::Class(qualify_def(db, def, name).into(), TyAttr::default()) + } + Definition::Enum(_) => { + Ty::Enum(qualify_def(db, def, name).into(), TyAttr::default()) } - Definition::Enum(_) => Ty::Enum(qualify_def(db, def, name), TyAttr::default()), _ => Ty::Unknown { attr: TyAttr::default(), }, @@ -463,10 +465,10 @@ fn resolve_path_to_ty<'db>( }; if let Some(def) = def { return match def { - Definition::Class(_) => Ty::Class(qualify_def(db, def, name), TyAttr::default()), - Definition::Enum(_) => Ty::Enum(qualify_def(db, def, name), TyAttr::default()), + Definition::Class(_) => Ty::Class(qualify_def(db, def, name).into(), TyAttr::default()), + Definition::Enum(_) => Ty::Enum(qualify_def(db, def, name).into(), TyAttr::default()), Definition::TypeAlias(_) => { - Ty::TypeAlias(qualify_def(db, def, name), TyAttr::default()) + Ty::TypeAlias(qualify_def(db, def, name).into(), TyAttr::default()) } _ => Ty::Unknown { attr: TyAttr::default(), diff --git a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs index 3d3fccc642..8301415083 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throws_semantics.rs @@ -17,8 +17,15 @@ pub(crate) fn resolve_alias_chain(ty: &Ty, aliases: &HashMap match aliases.get(qtn) { - Some(expanded) => resolved = expanded.clone(), + Ty::TypeAlias(qn, _) => match aliases.get(qn.qtn()) { + Some(expanded) => { + resolved = crate::normalize::instantiate_alias( + qn.qtn(), + expanded, + qn.type_args(), + aliases, + ); + } None => break, }, _ => break, diff --git a/baml_language/crates/baml_compiler2_tir/src/ty.rs b/baml_language/crates/baml_compiler2_tir/src/ty.rs index c716f7b3bd..3032e7cc33 100644 --- a/baml_language/crates/baml_compiler2_tir/src/ty.rs +++ b/baml_language/crates/baml_compiler2_tir/src/ty.rs @@ -1,6 +1,6 @@ //! Resolved type representation — the output of type resolution. -use std::fmt; +use std::{fmt, ops::Deref}; use baml_base::Name; pub use baml_base::attr::TyAttr; @@ -18,8 +18,11 @@ pub struct QualifiedTypeName { namespace: Vec, /// The short/local name of the type (e.g. `"Foo"`). name: Name, - /// Unresolved generic type parameter names (e.g. `["T"]` for `Array`). - /// Empty for non-generic types or when concrete type args are substituted. + /// Unresolved generic type parameter names on the definition itself + /// (e.g. `["T"]` for the declared class `Array`). + /// + /// Concrete instantiations like `Array` are represented on + /// `NominalTypeRef::type_args`, not here. pub generic_params: Vec, } @@ -94,18 +97,106 @@ impl fmt::Display for QualifiedTypeName { } } +/// A nominal type reference with optional concrete type arguments. +/// +/// `QualifiedTypeName` identifies the declaration; `type_args` capture an +/// instantiated use like `Handler`. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct NominalTypeRef { + qtn: QualifiedTypeName, + type_args: Vec, +} + +impl NominalTypeRef { + pub fn new(qtn: QualifiedTypeName) -> Self { + Self { + qtn, + type_args: Vec::new(), + } + } + + pub fn new_with_type_args(qtn: QualifiedTypeName, type_args: Vec) -> Self { + Self { qtn, type_args } + } + + pub fn qtn(&self) -> &QualifiedTypeName { + &self.qtn + } + + pub fn package(&self) -> &Name { + self.qtn.package() + } + + pub fn namespace(&self) -> &Vec { + self.qtn.namespace() + } + + pub fn name(&self) -> &Name { + self.qtn.name() + } + + pub fn type_args(&self) -> &[Ty] { + &self.type_args + } + + pub fn into_qtn(self) -> QualifiedTypeName { + self.qtn + } + + #[must_use] + pub fn with_type_args(&self, type_args: Vec) -> Self { + Self { + qtn: self.qtn.clone(), + type_args, + } + } + + pub fn is_panic_type(&self) -> bool { + self.qtn.is_panic_type() + } +} + +impl Deref for NominalTypeRef { + type Target = QualifiedTypeName; + + fn deref(&self) -> &Self::Target { + &self.qtn + } +} + +impl From for NominalTypeRef { + fn from(qtn: QualifiedTypeName) -> Self { + Self::new(qtn) + } +} + +impl fmt::Display for NominalTypeRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.qtn)?; + if !self.type_args.is_empty() { + let args: Vec<_> = self + .type_args + .iter() + .map(std::string::ToString::to_string) + .collect(); + write!(f, "<{}>", args.join(", "))?; + } + Ok(()) + } +} + /// Resolved type — the output of type resolution (Pass 2). #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum Ty { /// A class type — just the name, no expansion. - Class(QualifiedTypeName, TyAttr), + Class(NominalTypeRef, TyAttr), /// An enum type. - Enum(QualifiedTypeName, TyAttr), + Enum(NominalTypeRef, TyAttr), /// An enum variant — Enum(qualified) . Variant(name). EnumVariant(QualifiedTypeName, Name, TyAttr), /// A type alias — opaque name reference, NOT expanded. /// Expansion happens lazily at subtype-checking time. - TypeAlias(QualifiedTypeName, TyAttr), + TypeAlias(NominalTypeRef, TyAttr), /// Primitive types. Primitive(PrimitiveType, TyAttr), /// T[] @@ -379,6 +470,20 @@ impl Ty { self } + #[must_use] + pub fn with_nominal_type_args(self, type_args: Vec) -> Ty { + if type_args.is_empty() { + return self; + } + + match self { + Ty::Class(nominal, attr) => Ty::Class(nominal.with_type_args(type_args), attr), + Ty::Enum(nominal, attr) => Ty::Enum(nominal.with_type_args(type_args), attr), + Ty::TypeAlias(nominal, attr) => Ty::TypeAlias(nominal.with_type_args(type_args), attr), + other => other, + } + } + /// Widen fresh literal types to their base primitive. /// /// Called at mutable binding sites (`let` without annotation). diff --git a/baml_language/crates/baml_compiler2_visualization/src/control_flow/from_ast.rs b/baml_language/crates/baml_compiler2_visualization/src/control_flow/from_ast.rs index fb87efff32..0d6998998b 100644 --- a/baml_language/crates/baml_compiler2_visualization/src/control_flow/from_ast.rs +++ b/baml_language/crates/baml_compiler2_visualization/src/control_flow/from_ast.rs @@ -1263,6 +1263,7 @@ mod tests { type_name: Some("MyResponse".into()), fields: vec![("ok".into(), field_val)], spreads: vec![], + type_args: vec![], }); let ret = stmts.alloc(ast::Stmt::Return(Some(obj))); Some(exprs.alloc(ast::Expr::Block { @@ -1324,6 +1325,7 @@ mod tests { type_name: Some("Result".into()), fields: vec![], spreads: vec![], + type_args: vec![], }); let ret_true = stmts.alloc(ast::Stmt::Return(Some(obj_true))); let then_b = exprs.alloc(ast::Expr::Block { @@ -1336,6 +1338,7 @@ mod tests { type_name: Some("Result".into()), fields: vec![("err".into(), err_val)], spreads: vec![], + type_args: vec![], }); let ret_false = stmts.alloc(ast::Stmt::Return(Some(obj_false))); let else_b = exprs.alloc(ast::Expr::Block { @@ -1375,6 +1378,7 @@ mod tests { type_name: Some("Resp".into()), fields: vec![("ok".into(), field_val)], spreads: vec![], + type_args: vec![], }); let body = ast::ExprBody { diff --git a/baml_language/crates/baml_lsp2_actions/src/completions.rs b/baml_language/crates/baml_lsp2_actions/src/completions.rs index 8c7f7c9abf..438ab534c3 100644 --- a/baml_language/crates/baml_lsp2_actions/src/completions.rs +++ b/baml_language/crates/baml_lsp2_actions/src/completions.rs @@ -655,7 +655,8 @@ fn definition_to_ty(db: &dyn Db, def: Definition<'_>) -> Option { pkg_info.package, pkg_info.namespace_path, class.name.clone(), - ), + ) + .into(), TyAttr::default(), )) } @@ -668,7 +669,8 @@ fn definition_to_ty(db: &dyn Db, def: Definition<'_>) -> Option { pkg_info.package, pkg_info.namespace_path, enum_data.name.clone(), - ), + ) + .into(), TyAttr::default(), )) } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/generic_function_inference.baml b/baml_language/crates/baml_tests/projects/function_type_throws/generic_function_inference.baml new file mode 100644 index 0000000000..c5b9ade443 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/generic_function_inference.baml @@ -0,0 +1,18 @@ +// === Generic function inference over nominal wrappers === +// T should be inferred from the nominal wrapper's type_args. + +class Wrapper { + inner: T +} + +function unwrap(w: Wrapper) -> T { + w.inner +} + +function test_unwrap_int() -> int { + unwrap(Wrapper { inner: 42 }) +} + +function test_unwrap_string() -> string { + unwrap(Wrapper { inner: "hello" }) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/generic_subtype_rejection.baml b/baml_language/crates/baml_tests/projects/function_type_throws/generic_subtype_rejection.baml new file mode 100644 index 0000000000..6122f84d03 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/generic_subtype_rejection.baml @@ -0,0 +1,15 @@ +// === Generic subtype rejection === +// Mismatched instantiations of the same class must be rejected. + +class Box { + value: T +} + +function take_string_box(b: Box) -> null { + null +} + +// Should produce a type error: Box is not assignable to Box +function bad_box_mismatch() -> null { + take_string_box(Box { value: 1 }) +} diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap index 24a192fa1f..7a3645e28e 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap @@ -421,7 +421,7 @@ fn baml.llm.Client.build_attempt(self: baml.llm.Client) -> baml.llm.Orchestratio let _0: baml.llm.OrchestrationStep[] // _0 // return let _1: baml.llm.Client // self // param let _2: baml.llm.PlannerState // planner_state - let _3: void[] + let _3: string[] let _4: baml.llm.PlannerState bb0: { @@ -605,7 +605,7 @@ fn baml.llm.Client.build_plan(self: baml.llm.Client) -> baml.llm.OrchestrationSt let _0: baml.llm.OrchestrationStep[] // _0 // return let _1: baml.llm.Client // self // param let _2: baml.llm.PlannerState // planner_state - let _3: void[] + let _3: string[] let _4: baml.llm.PlannerState bb0: { diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap index f43001d355..db95de141d 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap @@ -36,8 +36,8 @@ fn testing.TestRegistry.expand_set(self: testing.TestRegistry, name: string) -> let _12: testing.TestSetRegistration let _13: int // start let _14: testing.TestCollector // sub_collector - let _15: void[] - let _16: void[] + let _15: testing.TestRegistration[] + let _16: testing.TestSetRegistration[] let _17: null let _18: (testing.TestCollector) -> null throws unknown let _19: testing.TestSetRegistration @@ -47,7 +47,7 @@ fn testing.TestRegistry.expand_set(self: testing.TestRegistry, name: string) -> let _23: int let _24: testing.TestRegistry // sub_registry let _25: testing.TestCollector - let _26: map + let _26: map let _27: int let _28: null let _29: map diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap index f634debc56..0eed641ea9 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap @@ -670,7 +670,7 @@ function user.DeepFnReturn() -> (x: int) -> (y: int) -> (z: int) -> (w: int) -> } !! 2354..2355: type mismatch: expected (x: int) -> (y: int) -> (z: int) -> (w: int) -> map, got 1 } -function user.DeepParamTypes(nested: map>>, callback: (x: int) -> (y: int) -> (z: int) -> string throws __throws_callback) -> string throws __throws_callback { +function user.DeepParamTypes(nested: map>>, callback: (x: int) -> (y: int) -> (z: int) -> string throws __throws_callback) -> string throws never { { : "done" "done" : "done" } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_function_inference.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_function_inference.snap new file mode 100644 index 0000000000..3a29b3a338 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_function_inference.snap @@ -0,0 +1,101 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Generic" +Function "function" +Word "inference" +Word "over" +Word "nominal" +Word "wrappers" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "T" +Word "should" +Word "be" +Word "inferred" +Word "from" +Word "the" +Word "nominal" +Word "wrapper" +Error "'" +Word "s" +Word "type_args" +Dot "." +Class "class" +Word "Wrapper" +Less "<" +Word "T" +Greater ">" +LBrace "{" +Word "inner" +Colon ":" +Word "T" +RBrace "}" +Function "function" +Word "unwrap" +Less "<" +Word "T" +Greater ">" +LParen "(" +Word "w" +Colon ":" +Word "Wrapper" +Less "<" +Word "T" +Greater ">" +RParen ")" +Arrow "->" +Word "T" +LBrace "{" +Word "w" +Dot "." +Word "inner" +RBrace "}" +Function "function" +Word "test_unwrap_int" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "unwrap" +LParen "(" +Word "Wrapper" +Less "<" +Word "int" +Greater ">" +LBrace "{" +Word "inner" +Colon ":" +IntegerLiteral "42" +RBrace "}" +RParen ")" +RBrace "}" +Function "function" +Word "test_unwrap_string" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Word "unwrap" +LParen "(" +Word "Wrapper" +Less "<" +Word "string" +Greater ">" +LBrace "{" +Word "inner" +Colon ":" +Quote "\"" +Word "hello" +Quote "\"" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_subtype_rejection.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_subtype_rejection.snap new file mode 100644 index 0000000000..b74b50b92e --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_subtype_rejection.snap @@ -0,0 +1,89 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Generic" +Word "subtype" +Word "rejection" +EqualsEquals "==" +Equals "=" +Slash "/" +Slash "/" +Word "Mismatched" +Word "instantiations" +Word "of" +Word "the" +Word "same" +Class "class" +Word "must" +Word "be" +Word "rejected" +Dot "." +Class "class" +Word "Box" +Less "<" +Word "T" +Greater ">" +LBrace "{" +Word "value" +Colon ":" +Word "T" +RBrace "}" +Function "function" +Word "take_string_box" +LParen "(" +Word "b" +Colon ":" +Word "Box" +Less "<" +Word "string" +Greater ">" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "null" +RBrace "}" +Slash "/" +Slash "/" +Word "Should" +Word "produce" +Word "a" +Word "type" +Word "error" +Colon ":" +Word "Box" +Less "<" +Word "int" +Greater ">" +Word "is" +Word "not" +Word "assignable" +Word "to" +Word "Box" +Less "<" +Word "string" +Greater ">" +Function "function" +Word "bad_box_mismatch" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "take_string_box" +LParen "(" +Word "Box" +Less "<" +Word "int" +Greater ">" +LBrace "{" +Word "value" +Colon ":" +IntegerLiteral "1" +RBrace "}" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_function_inference.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_function_inference.snap new file mode 100644 index 0000000000..c2c0f7b0a4 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_function_inference.snap @@ -0,0 +1,122 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + CLASS_DEF + KW_CLASS "class" + WORD "Wrapper" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "T" + WORD "T" + GREATER ">" + L_BRACE "{" + FIELD + WORD "inner" + COLON ":" + TYPE_EXPR "T" + WORD "T" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "unwrap" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "T" + WORD "T" + GREATER ">" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "w" + COLON ":" + TYPE_EXPR + WORD "Wrapper" + TYPE_ARGS + LESS "<" + TYPE_EXPR "T" + WORD "T" + GREATER ">" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "T" + WORD "T" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + PATH_EXPR "w.inner" + WORD "w" + DOT "." + WORD "inner" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_unwrap_int" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "unwrap" + CALL_ARGS + L_PAREN "(" + OBJECT_LITERAL + PATH_EXPR + WORD "Wrapper" + GENERIC_ARGS + LESS "<" + TYPE_EXPR "int" + WORD "int" + GREATER ">" + L_BRACE "{" + OBJECT_FIELD "inner: 42" + WORD "inner" + COLON ":" + INTEGER_LITERAL "42" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_unwrap_string" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "unwrap" + CALL_ARGS + L_PAREN "(" + OBJECT_LITERAL + PATH_EXPR + WORD "Wrapper" + GENERIC_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + L_BRACE "{" + OBJECT_FIELD + WORD "inner" + COLON ":" + STRING_LITERAL "hello" + QUOTE """ + WORD "hello" + QUOTE """ + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_subtype_rejection.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_subtype_rejection.snap new file mode 100644 index 0000000000..3469e6ed98 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_subtype_rejection.snap @@ -0,0 +1,81 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + CLASS_DEF + KW_CLASS "class" + WORD "Box" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "T" + WORD "T" + GREATER ">" + L_BRACE "{" + FIELD + WORD "value" + COLON ":" + TYPE_EXPR "T" + WORD "T" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "take_string_box" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "b" + COLON ":" + TYPE_EXPR + WORD "Box" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR "{ + null +}" + L_BRACE "{" + WORD "null" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "bad_box_mismatch" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "take_string_box" + CALL_ARGS + L_PAREN "(" + OBJECT_LITERAL + PATH_EXPR + WORD "Box" + GENERIC_ARGS + LESS "<" + TYPE_EXPR "int" + WORD "int" + GREATER ">" + L_BRACE "{" + OBJECT_FIELD "value: 1" + WORD "value" + COLON ":" + INTEGER_LITERAL "1" + R_BRACE "}" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index 6e4a42dd72..8408609079 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -224,6 +224,18 @@ type user.Wrapper = (() -> int) -> int function user.use_alias_pure(f: user.AliasPure) -> int [expr] { { } f() } +class user.Wrapper { + inner: user.T +} +function user.test_unwrap_int() -> int [expr] { + { } unwrap(user.Wrapper { inner: 42 }) +} +function user.test_unwrap_string() -> string [expr] { + { } unwrap(user.Wrapper { inner: "hello" }) +} +function user.unwrap(w: user.Wrapper) -> user.T [expr] { + { } w.inner +} class user.Handler { run: () -> null } @@ -245,6 +257,15 @@ function user.use_int_handler(h: user.Handler) -> null [expr] { function user.use_string_handler(h: user.Handler) -> null [expr] { { } h.run() } +class user.Box { + value: user.T +} +function user.bad_box_mismatch() -> null [expr] { + { } take_string_box(user.Box { value: 1 }) +} +function user.take_string_box(b: user.Box) -> null [expr] { + { } null +} class user.Task { name: string run: () -> null diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index 92fc402c2f..c436d1c1cd 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -1547,6 +1547,53 @@ fn user.use_alias_pure(f: () -> int) -> int { } } +fn user.test_unwrap_int() -> int { + // Locals: + let _0: int // _0 // return + let _1: void + + bb0: { + _1 = Wrapper { const 42_i64 }; + _0 = call const fn user.unwrap(copy _1) -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.test_unwrap_string() -> string { + // Locals: + let _0: string // _0 // return + let _1: void + + bb0: { + _1 = Wrapper { const "hello" }; + _0 = call const fn user.unwrap(copy _1) -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.unwrap(w: (() -> int) -> int throws string) -> void { + // Locals: + let _0: void // _0 // return + let _1: (() -> int) -> int throws string // w // param + let _2: string + + bb0: { + _2 = const "inner"; + _0 = copy _1[_2]; + goto -> bb1; + } + + bb1: { + return; + } +} + fn user.make_int_handler() -> Handler { // Locals: let _0: Handler // _0 // return @@ -1639,7 +1686,7 @@ fn user.use_int_handler(h: Handler) -> null { // Locals: let _0: null // _0 // return let _1: Handler // h // param - let _2: () -> null + let _2: () -> null throws int bb0: { _2 = copy _1.0; @@ -1655,7 +1702,7 @@ fn user.use_string_handler(h: Handler) -> null { // Locals: let _0: null // _0 // return let _1: Handler // h // param - let _2: () -> null + let _2: () -> null throws string bb0: { _2 = copy _1.0; @@ -1667,6 +1714,36 @@ fn user.use_string_handler(h: Handler) -> null { } } +fn user.bad_box_mismatch() -> null { + // Locals: + let _0: null // _0 // return + let _1: Box + + bb0: { + _1 = Box { const 1_i64 }; + _0 = call const fn user.take_string_box(copy _1) -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.take_string_box(b: Box) -> null { + // Locals: + let _0: null // _0 // return + let _1: Box // b // param + + bb0: { + _0 = const null; + goto -> bb1; + } + + bb1: { + return; + } +} + fn user.make_mixed_tasks() -> Task[] { // Locals: let _0: Task[] // _0 // return @@ -1719,7 +1796,7 @@ fn user.run_tasks(tasks: Task[]) -> null { let _5: Task let _6: Task // task let _7: void - let _8: () -> null + let _8: () -> null throws string | int let _9: Task bb0: { diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index e8a9cd9a12..0ed68af65a 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -148,7 +148,7 @@ lambda user.make_throwing_handler { } function user.test_method_runner() -> int throws never { { : int - let runner = MethodRunner { value: 1 } : user.MethodRunner + let runner = MethodRunner { value: 1 } : user.MethodRunner runner.apply((x: int) -> int { ... }) : int (x: int) -> int { ... } : (x: int) -> int { @@ -638,29 +638,52 @@ type user.NoThrow$stream = never type user.AliasPure$stream = unknown type user.Mapper$stream = unknown type user.Wrapper$stream = unknown +class user.Wrapper { + inner: T +} +function user.unwrap(w: user.Wrapper) -> T throws never { + { : unknown + w.inner : unknown + } + !! 201..206: type `(() -> int) -> int throws string` has no member `inner` + !! 196..206: type mismatch: expected T, got unknown +} +function user.test_unwrap_int() -> int throws never { + { : int + unwrap(Wrapper { inner: 42 }) : int + } +} +function user.test_unwrap_string() -> string throws never { + { : string + unwrap(Wrapper { inner: "hello" }) : string + } +} +class user.Wrapper$stream { + inner: null | unknown +} class user.Handler { run: () -> null throws E } -function user.use_string_handler(h: user.Handler) -> null throws never { +function user.use_string_handler(h: user.Handler) -> null throws string { { : null h.run() : null } } -function user.use_int_handler(h: user.Handler) -> null throws never { +function user.use_int_handler(h: user.Handler) -> null throws int { { : null h.run() : null } } -function user.make_string_handler() -> user.Handler throws never { - { : user.Handler - Handler { run: () -> null { ... } } : user.Handler +function user.make_string_handler() -> user.Handler throws never { + { : user.Handler + Handler { run: () -> null { ... } } : user.Handler } } lambda user.make_string_handler { } -function user.make_int_handler() -> user.Handler throws never { - { : user.Handler - Handler { run: () -> null { ... } } : user.Handler +function user.make_int_handler() -> user.Handler throws never { + { : user.Handler + Handler { run: () -> null { ... } } : user.Handler } } lambda user.make_int_handler { @@ -678,11 +701,28 @@ function user.test_int_handler() -> null throws never { class user.Handler$stream { run: unknown } +class user.Box { + value: T +} +function user.take_string_box(b: user.Box) -> null throws never { + { : null + null : null + } +} +function user.bad_box_mismatch() -> null throws never { + { : null + take_string_box(Box { value: 1 }) : null + } + !! 323..344: type mismatch: expected user.Box, got user.Box +} +class user.Box$stream { + value: null | unknown +} class user.Task { name: string run: () -> null throws E } -function user.run_tasks(tasks: user.Task[]) -> null throws never { +function user.run_tasks(tasks: user.Task[]) -> null throws int | string { { : null for task in tasks { : null @@ -691,9 +731,9 @@ function user.run_tasks(tasks: user.Task[]) -> null throws never { null : null } } -function user.make_mixed_tasks() -> user.Task[] throws never { - { : user.Task[] - [Task { name: "a", run: () -> null { ... } }, Task { name: "b", run: () -> null { ... } }] : user.Task[] +function user.make_mixed_tasks() -> user.Task[] throws never { + { : user.Task[] + [Task { name: "a", run: () -> null { ... } }, Task { name: "b", run: () -> null { ... } }] : user.Task[] } } lambda user.make_mixed_tasks { diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index e3a5e8b997..8c6f45604e 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -42,6 +42,53 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ + [validation] Error: Name `Wrapper` defined 2 times as: type, class + ╭─[ fn_type_alias_throws.baml:20:6 ] + │ + 20 │ type Wrapper = (() -> int) -> int throws string + │ ───┬─── + │ ╰───── first defined as type here + │ + ├─[ generic_function_inference.baml:4:7 ] + │ + 4 │ class Wrapper { + │ ───┬─── + │ ╰───── duplicate class definition + │ + │ Note: Error code: E0011 +────╯ + + [type] Error: type mismatch: expected T, got unknown + ╭─[ generic_function_inference.baml:8:41 ] + │ + 8 │ ╭─▶ function unwrap(w: Wrapper) -> T { + 9 │ ├─▶ w.inner + │ │ + │ ╰─────────────── type mismatch: expected T, got unknown + │ + │ Note: Error code: E0001 +───╯ + + [type] Error: type `(() -> int) -> int throws string` has no member `inner` + ╭─[ generic_function_inference.baml:9:5 ] + │ + 9 │ w.inner + │ ──┬── + │ ╰──── type `(() -> int) -> int throws string` has no member `inner` + │ + │ Note: Error code: E0001 +───╯ + + [type] Error: type mismatch: expected user.Box, got user.Box + ╭─[ generic_subtype_rejection.baml:14:19 ] + │ + 14 │ take_string_box(Box { value: 1 }) + │ ──────────┬────────── + │ ╰──────────── type mismatch: expected user.Box, got user.Box + │ + │ Note: Error code: E0001 +────╯ + [type] Error: throws contract violation: `never` is missing throws from callback parameter `f` ╭─[ hof_throws.baml:58:57 ] │ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 2462701429..298311c96d 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -155,6 +155,14 @@ function user.apply_with_many_args(a: int, b: string, f: (int, string) -> string return } +function user.bad_box_mismatch() -> null { + alloc_instance Box + load_const 1 + init_field .value + call user.take_string_box + return +} + function user.caller() -> string { load_const 1 call user.declared_may_fail @@ -538,6 +546,11 @@ function user.safe_caller() -> string { return } +function user.take_string_box(b: Box) -> null { + load_const null + return +} + function user.takes_throwing(f: () -> int throws string) -> int { load_var f call_indirect @@ -1019,6 +1032,22 @@ function user.test_typed_lambda_throws_mismatch() -> int { return } +function user.test_unwrap_int() -> int { + alloc_instance Wrapper + load_const 42 + init_field .inner + call user.unwrap + return +} + +function user.test_unwrap_string() -> string { + alloc_instance Wrapper + load_const "hello" + init_field .inner + call user.unwrap + return +} + function user.test_use_pure() -> int { call user.make_pure call_indirect @@ -1241,6 +1270,13 @@ function user.throw_various_errors(x: int) -> string { throw } +function user.unwrap(w: (() -> int) -> int throws string) -> void { + load_var w + load_const "inner" + load_map_element + return +} + function user.use_alias_pure(f: () -> int) -> int { load_var f call_indirect diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_function_inference.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_function_inference.snap new file mode 100644 index 0000000000..6404b395b4 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_function_inference.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +An element at generic_function_inference.baml:4:14 was a token when it should have been a node. diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_subtype_rejection.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_subtype_rejection.snap new file mode 100644 index 0000000000..6fac961f6a --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__generic_subtype_rejection.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== STRONG AST ERROR === +An element at generic_subtype_rejection.baml:4:10 was a token when it should have been a node. diff --git a/baml_language/crates/baml_tests/snapshots/test_with_runner_ambiguity/baml_tests__test_with_runner_ambiguity__04_tir.snap b/baml_language/crates/baml_tests/snapshots/test_with_runner_ambiguity/baml_tests__test_with_runner_ambiguity__04_tir.snap index ea7e253b37..262e4bfc6d 100644 --- a/baml_language/crates/baml_tests/snapshots/test_with_runner_ambiguity/baml_tests__test_with_runner_ambiguity__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/test_with_runner_ambiguity/baml_tests__test_with_runner_ambiguity__04_tir.snap @@ -26,6 +26,7 @@ function user.$init_test_19(registry: testing.TestCollector) -> ? throws never { } !! 384..393: unresolved name: my_runner !! 512..527: unresolved name: testing.Quorum + !! 770..797: type mismatch: expected testing.TestRunner?, got user.MyRunner } lambda user.$init_test_19 { } diff --git a/baml_language/crates/baml_tests/snapshots/test_with_runner_ambiguity/baml_tests__test_with_runner_ambiguity__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/test_with_runner_ambiguity/baml_tests__test_with_runner_ambiguity__05_diagnostics.snap index 284e993140..d4de4a5cfa 100644 --- a/baml_language/crates/baml_tests/snapshots/test_with_runner_ambiguity/baml_tests__test_with_runner_ambiguity__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/test_with_runner_ambiguity/baml_tests__test_with_runner_ambiguity__05_diagnostics.snap @@ -22,3 +22,14 @@ source: crates/baml_tests/src/generated_tests.rs │ │ Note: Error code: E0001 ────╯ + + [type] Error: type mismatch: expected testing.TestRunner?, got user.MyRunner + ╭─[ main.baml:23:40 ] + │ + 23 │ ╭─▶ test "constructor runner" with MyRunner { threshold: 0.5 } { + 24 │ ├─▶ assert.is_true(true) + │ │ + │ ╰──────────────────────────── type mismatch: expected testing.TestRunner?, got user.MyRunner + │ + │ Note: Error code: E0001 +────╯ diff --git a/baml_language/crates/baml_tests/snapshots/type_aliases/baml_tests__type_aliases__04_tir.snap b/baml_language/crates/baml_tests/snapshots/type_aliases/baml_tests__type_aliases__04_tir.snap index 25dc02d70c..f187432cb7 100644 --- a/baml_language/crates/baml_tests/snapshots/type_aliases/baml_tests__type_aliases__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/type_aliases/baml_tests__type_aliases__04_tir.snap @@ -33,12 +33,12 @@ function user.EchoJSON(i: user.JSON) -> user.JSON throws never { { : user.NotJSON NotJSON { inner: "Whoops" } : user.NotJSON } + !! 165..191: type mismatch: expected user.JSON, got user.NotJSON } function user.MakeJSON(i: user.NotJSON) -> user.JSON throws never { { : unknown i.inner : unknown } - !! 92..97: unresolved type: name } type user.JSON$stream = int | float | string | bool | user.JSON$stream[] | map !! 11..67: unresolved type: json diff --git a/baml_language/crates/baml_tests/snapshots/type_aliases/baml_tests__type_aliases__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/type_aliases/baml_tests__type_aliases__05_diagnostics.snap index 20c1512d8d..3562c04342 100644 --- a/baml_language/crates/baml_tests/snapshots/type_aliases/baml_tests__type_aliases__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/type_aliases/baml_tests__type_aliases__05_diagnostics.snap @@ -29,15 +29,15 @@ source: crates/baml_tests/src/generated_tests.rs │ ──┬── │ ╰──── unresolved type: name │ - │ Note: Error code: E0001 + │ Note: Error code: E0002 ───╯ - [type] Error: unresolved type: name - ╭─[ type_aliases_json.baml:4:8 ] + [type] Error: type mismatch: expected user.JSON, got user.NotJSON + ╭─[ type_aliases_json.baml:9:3 ] │ - 4 │ inner name - │ ──┬── - │ ╰──── unresolved type: name + 9 │ NotJSON{ inner: "Whoops" } + │ ─────────────┬──────────── + │ ╰────────────── type mismatch: expected user.JSON, got user.NotJSON │ - │ Note: Error code: E0002 + │ Note: Error code: E0001 ───╯ diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs index 0e379f1211..78f1b67064 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs @@ -1023,7 +1023,8 @@ pub(crate) mod support { baml_compiler2_tir::ty::Ty::Class( baml_compiler2_tir::lower_type_expr::qualify_def( db, def, cn, - ), + ) + .into(), Default::default(), ) }) diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/phase5.rs b/baml_language/crates/baml_tests/src/compiler2_tir/phase5.rs index d6707ac771..e7c2a2ab5c 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/phase5.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/phase5.rs @@ -554,6 +554,7 @@ fn cross_namespace_type_resolution_via_root() { &db, &baml_compiler2_ast::TypeExpr::Path { segments, + type_args: vec![], attrs: vec![], }, pkg_items, @@ -579,6 +580,7 @@ fn cross_namespace_type_resolution_via_root() { &db, &baml_compiler2_ast::TypeExpr::Path { segments, + type_args: vec![], attrs: vec![], }, pkg_items, @@ -618,6 +620,7 @@ fn same_namespace_resolution_no_prefix() { &db, &baml_compiler2_ast::TypeExpr::Path { segments, + type_args: vec![], attrs: vec![], }, pkg_items, @@ -673,6 +676,7 @@ fn nested_namespace_resolution() { &db, &baml_compiler2_ast::TypeExpr::Path { segments, + type_args: vec![], attrs: vec![], }, pkg_items, @@ -710,6 +714,7 @@ fn bare_name_cross_namespace_rejected() { &db, &baml_compiler2_ast::TypeExpr::Path { segments, + type_args: vec![], attrs: vec![], }, pkg_items, @@ -752,6 +757,7 @@ fn multi_segment_bare_path_rejected() { &db, &baml_compiler2_ast::TypeExpr::Path { segments, + type_args: vec![], attrs: vec![], }, pkg_items, From 53b47b85ebf21620095b3b485e3d26171046dbc6 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 12:50:05 -0500 Subject: [PATCH 19/26] Propagate named-function throws through typed member field calls Add collect_member_field_call_throws to the pre-inference throw pipeline so that throws from typed member calls (h.run(), self.run(), task.run()) propagate to callers. Uses PackageResolutionContext for member lookup with generic substitution applied. test_string_handler now correctly infers throws string, test_int_handler infers throws int, and test_run_mixed_tasks infers throws int | string via transitive propagation. --- .../baml_compiler2_tir/src/throw_inference.rs | 229 +++++++++++++++++- ...l_tests__function_type_throws__04_tir.snap | 6 +- 2 files changed, 227 insertions(+), 8 deletions(-) diff --git a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs index f9188f8aad..8b7802ab26 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs @@ -67,7 +67,20 @@ pub fn function_throw_sets<'db>( package_id: PackageId<'db>, ) -> FunctionThrowSets { let pkg_items = package_items(db, package_id); - let aliases = collect_type_aliases(db, pkg_items); + let res_ctx = crate::package_interface::package_resolution_context(db, package_id); + let mut aliases = collect_type_aliases(db, pkg_items); + // Merge dependency type aliases for cross-package field type resolution + for (_dep_name, dep_iface) in &res_ctx.dep_interfaces { + for types_in_ns in dep_iface.types.values() { + for exported in types_in_ns.values() { + if let crate::package_interface::ExportedType::TypeAlias { qtn, resolved } = + exported + { + aliases.insert(qtn.clone(), resolved.clone()); + } + } + } + } // Load dependency interfaces for cross-package throw lookup let dep_interfaces: Vec<(Name, &crate::package_interface::PackageInterface)> = package_dependencies(db, package_id) @@ -128,6 +141,17 @@ pub fn function_throw_sets<'db>( expr_body, &aliases, )); + let (member_facts, _) = collect_member_field_call_throws( + db, + res_ctx, + &func_ns, + &func_data.generic_params, + sig.as_ref(), + expr_body, + &aliases, + None, + ); + direct.extend(member_facts); direct } else { BTreeSet::new() @@ -137,7 +161,20 @@ pub fn function_throw_sets<'db>( has_declared_contract.insert(key.clone(), declared_throws.is_some()); if let baml_compiler2_hir::body::FunctionBody::Expr(expr_body) = body.as_ref() { - call_edges.insert(key, collect_call_targets(expr_body)); + // Build combined target set: syntactic call targets + member call edges + let mut targets = collect_call_targets(expr_body); + let (_, member_edges) = collect_member_field_call_throws( + db, + res_ctx, + &func_ns, + &func_data.generic_params, + sig.as_ref(), + expr_body, + &aliases, + None, + ); + targets.extend(member_edges); + call_edges.insert(key, targets); } } @@ -194,6 +231,17 @@ pub fn function_throw_sets<'db>( expr_body, &aliases, )); + let (member_facts, _) = collect_member_field_call_throws( + db, + res_ctx, + &method_ns, + &method_generic_params, + sig.as_ref(), + expr_body, + &aliases, + Some((class_name, class_data.generic_params.as_slice())), + ); + direct.extend(member_facts); direct } else { BTreeSet::new() @@ -203,10 +251,21 @@ pub fn function_throw_sets<'db>( has_declared_contract.insert(key.clone(), declared_throws.is_some()); if let baml_compiler2_hir::body::FunctionBody::Expr(expr_body) = body.as_ref() { - // Rewrite "self.X" call targets to "ClassName.X" so edges - // connect to the correct graph nodes. let raw_targets = collect_call_targets(expr_body); - let rewritten: BTreeSet = raw_targets + let (_, member_edges) = collect_member_field_call_throws( + db, + res_ctx, + &method_ns, + &method_generic_params, + sig.as_ref(), + expr_body, + &aliases, + Some((class_name, class_data.generic_params.as_slice())), + ); + // Merge syntactic targets + member edges, then rewrite self references + let mut combined = raw_targets; + combined.extend(member_edges); + let rewritten: BTreeSet = combined .into_iter() .map(|t| rewrite_self_target(&t, class_name)) .collect(); @@ -546,6 +605,166 @@ fn collect_direct_param_call_throws<'db>( out } +/// Extract direct throw facts from member field calls on typed parameters/self. +/// +/// Handles two sub-cases: +/// 1. Function-typed field calls (e.g., `h.run()` where `run` is a `Class::fields` entry): +/// resolve the base type from the function signature, look up the field via +/// `PackageResolutionContext::lookup_class_fields` (with generic substitution applied), +/// and extract throws as direct facts if the field is `Ty::Function`. +/// 2. Named method calls (e.g., `h.do_thing()` where `do_thing` is a `Class::methods` entry): +/// rewrite the call target to namespace-qualified `"ns.ClassName.method"` form and add +/// as a call-graph edge. +#[allow(clippy::too_many_arguments)] +fn collect_member_field_call_throws<'db>( + db: &'db dyn crate::Db, + res_ctx: &crate::package_interface::PackageResolutionContext<'db>, + ns_context: &[Name], + generic_params: &[Name], + sig: &baml_compiler2_hir::signature::FunctionSignature, + body: &ExprBody, + aliases: &HashMap, + // (class_name, class_generic_params) — needed to build accurate self type + // with TypeVar type_args, e.g. Handler not bare Handler + class_context: Option<(&Name, &[Name])>, +) -> (BTreeSet, BTreeSet) { + let mut direct_facts = BTreeSet::new(); + let mut extra_edges = BTreeSet::new(); + + let pkg_items = res_ctx.own_items; + // Build param name → Ty map from the function signature. + let param_types: HashMap = sig + .params + .iter() + .map(|(name, te)| { + let mut diags = Vec::new(); + let ty = + lower_type_expr_in_ns(db, te, pkg_items, ns_context, generic_params, &mut diags); + (name.clone(), ty) + }) + .collect(); + + // Extend with for-loop variable types. + // For `for (let elem in collection)` where `collection` is a known param of type `T[]`, + // map `elem` → `T`. This handles patterns like `for task in tasks` where + // `tasks: Task[]`. + let mut local_types: HashMap = param_types.clone(); + for (_, stmt) in body.stmts.iter() { + if let baml_compiler2_ast::Stmt::For { + binding, + collection, + .. + } = stmt + { + // Get the binding name + let binding_name = match &body.patterns[*binding] { + baml_compiler2_ast::Pattern::Binding(name) + | baml_compiler2_ast::Pattern::TypedBinding { name, .. } => name.clone(), + _ => continue, + }; + // Get the collection's type by resolving its path to a param + let collection_ty = match &body.exprs[*collection] { + Expr::Path(segments) if segments.len() == 1 => { + param_types.get(&segments[0]).cloned() + } + _ => None, + }; + if let Some(coll_ty) = collection_ty { + // Peel List type to get element type + let elem_ty = match coll_ty { + Ty::List(inner, _) => Some(*inner), + _ => None, + }; + if let Some(elem) = elem_ty { + local_types.insert(binding_name, elem); + } + } + } + } + + for (_, expr) in body.exprs.iter() { + let callee_id = match expr { + Expr::Call { callee, .. } | Expr::OptionalCall { callee, .. } => *callee, + _ => continue, + }; + + // Match FieldAccess { base, field } pattern + let (base_id, field) = match &body.exprs[callee_id] { + Expr::FieldAccess { base, field } => (*base, field), + _ => continue, + }; + + // Resolve base to a type: direct parameter, for-loop variable, or `self` + let base_ty = match &body.exprs[base_id] { + Expr::Path(segments) if segments.len() == 1 => { + let base_name = &segments[0]; + if base_name.as_str() == "self" { + // self-based member calls — resolve self type from class_context. + // Build accurate self type with TypeVar type_args so that + // lookup_class_fields applies correct generic substitution. + class_context.and_then(|(class_name, class_generic_params)| { + let def = pkg_items.lookup_type(ns_context, class_name)?; + match def { + Definition::Class(_) => { + let qtn = qualify_def(db, def, class_name); + let type_args: Vec = class_generic_params + .iter() + .map(|name| Ty::TypeVar(name.clone(), TyAttr::default())) + .collect(); + let nominal = + crate::ty::NominalTypeRef::new_with_type_args(qtn, type_args); + Some(Ty::Class(nominal, TyAttr::default())) + } + _ => None, + } + }) + } else { + local_types.get(base_name).cloned() + } + } + _ => None, + }; + + let Some(base_ty) = base_ty else { + continue; + }; + + // Resolve the member on the base type + let resolved_base = crate::throws_semantics::resolve_alias_chain(&base_ty, aliases); + let Ty::Class(class_name, _) = &resolved_base else { + continue; + }; + + // Look up the field via PackageResolutionContext (with generic substitution) + let fields = res_ctx.lookup_class_fields(db, class_name); + if let Some((_, field_ty)) = fields.iter().find(|(n, _)| n == field) { + // Function-typed field → extract throws as direct facts + if let Some(facts) = function_throws_facts(field_ty, aliases) { + direct_facts.extend(facts.into_iter().filter(|fact| { + !matches!(fact, Ty::TypeVar(_, _) | Ty::Never { .. } | Ty::Void { .. }) + })); + } + continue; + } + + // Check if it's a named method → validate via res_ctx and add as call-graph edge. + // We MUST validate the method exists before adding an edge, because + // AnalysisGraph::add_edge (analysis.rs:54-58) auto-creates missing target + // nodes with empty facts. A bogus edge would inject a phantom node and + // alter graph topology, potentially masking real throw propagation. + // We also must use the class declaration namespace for the key, NOT the caller's + // ns_context, so the edge matches the declaration-side key built by function_key(). + if let Some(_resolved_method) = res_ctx.lookup_class_method(db, class_name, field) { + let class_ns = class_name.namespace(); + let method_short = Name::new(format!("{}.{}", class_name.name(), field)); + let method_key = throw_set_key(class_ns, &method_short); + extra_edges.insert(method_key); + } + } + + (direct_facts, extra_edges) +} + /// Look up a function's transitive throw set from dependency interfaces. fn lookup_dep_throw_set<'a>( dep_interfaces: &'a [(Name, &crate::package_interface::PackageInterface)], diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 0ed68af65a..bb01c08a04 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -688,12 +688,12 @@ function user.make_int_handler() -> user.Handler throws never { } lambda user.make_int_handler { } -function user.test_string_handler() -> null throws never { +function user.test_string_handler() -> null throws string { { : null use_string_handler(make_string_handler()) : null } } -function user.test_int_handler() -> null throws never { +function user.test_int_handler() -> null throws int { { : null use_int_handler(make_int_handler()) : null } @@ -740,7 +740,7 @@ lambda user.make_mixed_tasks { } lambda user.make_mixed_tasks { } -function user.test_run_mixed_tasks() -> null throws never { +function user.test_run_mixed_tasks() -> null throws int | string { { : null run_tasks(make_mixed_tasks()) : null } From 1bb6e8bf2068c6258fb236502c1bf03b1473d0e3 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 15:29:37 -0500 Subject: [PATCH 20/26] Render generic args in LSP hover and type display Update display_type_expr to render type_args recursively for TypeExpr::Path (e.g. Handler instead of Handler). Include generic_params in hover tooltips for class and function definitions so class Handler and function unwrap show their type parameters. Add LSP hover test infrastructure (CursorTest::type_info) and 11 tests covering generic class/function definitions, usage-site resolution, nested generics, optional/array generic params, and markdown rendering. --- .../crates/baml_lsp2_actions/src/lib.rs | 2 + .../crates/baml_lsp2_actions/src/testing.rs | 5 + .../crates/baml_lsp2_actions/src/type_info.rs | 25 +- .../baml_lsp2_actions/src/type_info_tests.rs | 256 ++++++++++++++++++ .../crates/baml_lsp2_actions/src/utils.rs | 17 +- 5 files changed, 299 insertions(+), 6 deletions(-) create mode 100644 baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs diff --git a/baml_language/crates/baml_lsp2_actions/src/lib.rs b/baml_language/crates/baml_lsp2_actions/src/lib.rs index 0a426625ec..337536fcdf 100644 --- a/baml_language/crates/baml_lsp2_actions/src/lib.rs +++ b/baml_language/crates/baml_lsp2_actions/src/lib.rs @@ -58,6 +58,8 @@ mod definition_at_tests; #[cfg(test)] mod testing; #[cfg(test)] +mod type_info_tests; +#[cfg(test)] mod usages_at_tests; // ── Db trait ────────────────────────────────────────────────────────────────── diff --git a/baml_language/crates/baml_lsp2_actions/src/testing.rs b/baml_language/crates/baml_lsp2_actions/src/testing.rs index adb921bbd5..511e489971 100644 --- a/baml_language/crates/baml_lsp2_actions/src/testing.rs +++ b/baml_language/crates/baml_lsp2_actions/src/testing.rs @@ -144,6 +144,11 @@ impl CursorTest { definition_at(&self.db, self.cursor.file, self.cursor.offset) } + /// Get hover type info at the cursor position. + pub(crate) fn type_info(&self) -> Option { + crate::type_info::type_at(&self.db, self.cursor.file, self.cursor.offset) + } + /// Find all usages/references at the cursor position. pub(crate) fn find_all_usages(&self) -> Vec { usages_at(&self.db, self.cursor.file, self.cursor.offset) diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info.rs b/baml_language/crates/baml_lsp2_actions/src/type_info.rs index 3d4625694a..610bcbdcb5 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info.rs @@ -193,6 +193,18 @@ fn type_info_for_definition(db: &dyn Db, def: Definition<'_>) -> TypeInfo { match def { Definition::Function(func_loc) => { let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); + let item_tree = baml_compiler2_hir::file_item_tree(db, func_loc.file(db)); + let func_data = &item_tree[func_loc.id(db)]; + let func_name = if func_data.generic_params.is_empty() { + sig.name.as_str().to_string() + } else { + let gparams: Vec<&str> = func_data + .generic_params + .iter() + .map(baml_base::Name::as_str) + .collect(); + format!("{}<{}>", sig.name.as_str(), gparams.join(", ")) + }; let params = sig .params .iter() @@ -205,7 +217,7 @@ fn type_info_for_definition(db: &dyn Db, def: Definition<'_>) -> TypeInfo { .collect(); let return_type = sig.return_type.as_ref().map(utils::display_type_expr); TypeInfo::Function { - name: sig.name.as_str().to_string(), + name: func_name, params, return_type, } @@ -214,7 +226,16 @@ fn type_info_for_definition(db: &dyn Db, def: Definition<'_>) -> TypeInfo { Definition::Class(class_loc) => { let item_tree = baml_compiler2_hir::file_item_tree(db, class_loc.file(db)); let class_data = &item_tree[class_loc.id(db)]; - let class_name = class_data.name.as_str().to_string(); + let class_name = if class_data.generic_params.is_empty() { + class_data.name.as_str().to_string() + } else { + let params: Vec<&str> = class_data + .generic_params + .iter() + .map(baml_base::Name::as_str) + .collect(); + format!("{}<{}>", class_data.name.as_str(), params.join(", ")) + }; // Use resolved field types (Salsa-cached). let resolved = baml_compiler2_tir::inference::resolve_class_fields(db, class_loc); diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs new file mode 100644 index 0000000000..238c550bd8 --- /dev/null +++ b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs @@ -0,0 +1,256 @@ +#[cfg(test)] +mod tests { + use crate::{testing::CursorTest, type_info::TypeInfo}; + + // ── Class definition hover ────────────────────────────────────────────── + + #[test] + fn hover_class_no_generics() { + let test = CursorTest::new( + r#" +class <[CURSOR]Person { + name: string +} +"#, + ); + let info = test.type_info().expect("should resolve"); + assert_eq!( + info, + TypeInfo::Class { + name: "Person".into(), + fields: vec![("name".into(), "string".into())], + } + ); + } + + #[test] + fn hover_generic_class_definition_shows_type_params() { + let test = CursorTest::new( + r#" +class <[CURSOR]Handler { + run: () -> null throws E +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Class { name, .. } => { + assert_eq!( + name, "Handler", + "Class hover should include generic params" + ); + } + other => panic!("Expected TypeInfo::Class, got: {other:?}"), + } + } + + #[test] + fn hover_generic_class_multiple_type_params() { + let test = CursorTest::new( + r#" +class <[CURSOR]Pair { + first: A + second: B +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Class { name, .. } => { + assert_eq!( + name, "Pair", + "Class hover should show all generic params" + ); + } + other => panic!("Expected TypeInfo::Class, got: {other:?}"), + } + } + + // ── Class usage-site hover (resolves to definition) ───────────────────── + + #[test] + fn hover_generic_class_at_usage_site() { + // Hovering over `Handler` in a param type resolves to the class definition. + // The hover should show `Handler` (the definition's generic params). + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +function use_handler(h: <[CURSOR]Handler) -> null { + h.run() +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Class { name, .. } => { + assert_eq!( + name, "Handler", + "Usage-site hover resolves to definition, should show generic params" + ); + } + other => panic!("Expected TypeInfo::Class, got: {other:?}"), + } + } + + // ── Function definition hover ─────────────────────────────────────────── + + #[test] + fn hover_function_no_generics() { + let test = CursorTest::new( + r#" +function <[CURSOR]greet(name: string) -> string { + name +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { name, .. } => { + assert_eq!(name, "greet"); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + } + + #[test] + fn hover_generic_function_shows_type_params() { + let test = CursorTest::new( + r#" +class Wrapper { + inner: T +} + +function <[CURSOR]unwrap(w: Wrapper) -> T { + w.inner +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { name, .. } => { + assert_eq!( + name, "unwrap", + "Function hover should include generic params" + ); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + } + + // ── Hover markdown rendering ──────────────────────────────────────────── + + #[test] + fn hover_markdown_generic_class() { + let test = CursorTest::new( + r#" +class <[CURSOR]Handler { + run: () -> null throws E +} +"#, + ); + let info = test.type_info().expect("should resolve"); + let md = info.to_hover_markdown(); + assert!( + md.contains("class Handler"), + "Markdown should contain 'class Handler', got: {md}" + ); + } + + #[test] + fn hover_markdown_generic_function() { + let test = CursorTest::new( + r#" +class Wrapper { + inner: T +} + +function <[CURSOR]unwrap(w: Wrapper) -> T { + w.inner +} +"#, + ); + let info = test.type_info().expect("should resolve"); + let md = info.to_hover_markdown(); + assert!( + md.contains("function unwrap"), + "Markdown should contain 'function unwrap', got: {md}" + ); + // The param type should render with generic args thanks to display_type_expr + assert!( + md.contains("Wrapper"), + "Param type should render as Wrapper, got: {md}" + ); + } + + #[test] + fn hover_markdown_function_with_generic_param_type() { + // Hover on `maybe_handler` — the function itself is not generic, + // but its parameter type `Handler?` has generic args. + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +function <[CURSOR]maybe_handler(h: Handler?) -> null { + null +} +"#, + ); + let info = test.type_info().expect("should resolve"); + let md = info.to_hover_markdown(); + assert!( + md.contains("Handler?"), + "Param type should render as Handler?, got: {md}" + ); + } + + #[test] + fn hover_markdown_function_with_nested_generics() { + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +class Wrapper { + inner: T +} + +function <[CURSOR]use_nested(w: Wrapper>) -> null { + null +} +"#, + ); + let info = test.type_info().expect("should resolve"); + let md = info.to_hover_markdown(); + assert!( + md.contains("Wrapper>"), + "Param type should show nested generics, got: {md}" + ); + } + + #[test] + fn hover_markdown_function_with_generic_array_param() { + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +function <[CURSOR]many_handlers(hs: Handler[]) -> null { + null +} +"#, + ); + let info = test.type_info().expect("should resolve"); + let md = info.to_hover_markdown(); + assert!( + md.contains("Handler[]"), + "Param type should render as Handler[], got: {md}" + ); + } +} diff --git a/baml_language/crates/baml_lsp2_actions/src/utils.rs b/baml_language/crates/baml_lsp2_actions/src/utils.rs index 19469b9a54..5db2e1c431 100644 --- a/baml_language/crates/baml_lsp2_actions/src/utils.rs +++ b/baml_language/crates/baml_lsp2_actions/src/utils.rs @@ -209,12 +209,21 @@ pub fn display_ty(ty: &Ty) -> String { /// produces output that matches the user's source syntax. pub fn display_type_expr(te: &TypeExpr) -> String { match te { - TypeExpr::Path { segments, .. } => { - // Use only the last segment for brevity (e.g. `baml.Foo` → `Foo`). - segments + TypeExpr::Path { + segments, + type_args, + .. + } => { + let base = segments .last() .map(|n| n.as_str().to_string()) - .unwrap_or_else(|| "unknown".to_string()) + .unwrap_or_else(|| "unknown".to_string()); + if type_args.is_empty() { + base + } else { + let args: Vec<_> = type_args.iter().map(display_type_expr).collect(); + format!("{}<{}>", base, args.join(", ")) + } } TypeExpr::Int { .. } => "int".to_string(), TypeExpr::Float { .. } => "float".to_string(), From bc521f098503e8b433a79bce01718d3bcc6e65fc Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 15:38:33 -0500 Subject: [PATCH 21/26] Refresh TIR snapshots for StructuralTy type_args shape change MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Twelve 04_tir snapshots updated with trivially larger serialization from the new Vec type_args on Class/Enum variants. MIR snapshots (04_5_mir) are unchanged — no scope leak. --- .../baml_tests__basic_types__04_tir.snap | 24 ++++++------- .../baml_tests__comment_in_type__04_tir.snap | 12 +++---- .../baml_tests__control_flow__04_tir.snap | 6 ++-- .../baml_tests__format_checks__04_tir.snap | 36 +++++++++---------- .../baml_tests__generator__04_tir.snap | 6 ++-- ...tests__header_in_llm_function__04_tir.snap | 6 ++-- ...aml_tests__parser_speculative__04_tir.snap | 30 ++++++++-------- .../baml_tests__parser_strings__04_tir.snap | 12 +++---- ...l_tests__pending_greaters_fix__04_tir.snap | 12 +++---- .../baml_tests__simple_function__04_tir.snap | 6 ++-- ...ml_tests__type_builder_errors__04_tir.snap | 6 ++-- ...baml_tests__type_builder_test__04_tir.snap | 6 ++-- 12 files changed, 81 insertions(+), 81 deletions(-) diff --git a/baml_language/crates/baml_tests/snapshots/basic_types/baml_tests__basic_types__04_tir.snap b/baml_language/crates/baml_tests/snapshots/basic_types/baml_tests__basic_types__04_tir.snap index eb8225fa25..3a4688ad0d 100644 --- a/baml_language/crates/baml_tests/snapshots/basic_types/baml_tests__basic_types__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/basic_types/baml_tests__basic_types__04_tir.snap @@ -26,57 +26,57 @@ class user.FieldAttributes$stream { class user.ClassAttributes$stream { name: null | string } -function user.GetUser(id: int) -> user.User throws never { +function user.GetUser(id: int) -> user.User throws unknown { baml.llm.call_llm_function(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "GetUser", map { "id": id }) : user.User } -function user.GetUser$render_prompt(id: int) -> baml.llm.PromptAst throws never { +function user.GetUser$render_prompt(id: int) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "GetUser", map { "id": id }) : baml.llm.PromptAst } -function user.GetUser$build_request(id: int) -> baml.http.Request throws never { +function user.GetUser$build_request(id: int) -> baml.http.Request throws unknown { baml.llm.build_request(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "GetUser", map { "id": id }) : baml.http.Request } function user.GetUser$parse(json: string) -> user.User throws never { baml.llm.parse("GetUser", json) : user.User } -function user.GetStatus() -> user.Status throws never { +function user.GetStatus() -> user.Status throws unknown { baml.llm.call_llm_function(GPT4, "GetStatus", map { }) : user.Status !! 127..173: unresolved name: GPT4 } -function user.GetStatus$render_prompt() -> baml.llm.PromptAst throws never { +function user.GetStatus$render_prompt() -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "GetStatus", map { }) : baml.llm.PromptAst !! 95..173: unresolved name: GPT4 } -function user.GetStatus$build_request() -> baml.http.Request throws never { +function user.GetStatus$build_request() -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "GetStatus", map { }) : baml.http.Request !! 95..173: unresolved name: GPT4 } function user.GetStatus$parse(json: string) -> user.Status throws never { baml.llm.parse("GetStatus", json) : user.Status } -function user.GetMixedList() -> (int | string)[] throws never { +function user.GetMixedList() -> (int | string)[] throws unknown { baml.llm.call_llm_function(GPT4, "GetMixedList", map { }) : (int | string)[] !! 258..308: unresolved name: GPT4 } -function user.GetMixedList$render_prompt() -> baml.llm.PromptAst throws never { +function user.GetMixedList$render_prompt() -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "GetMixedList", map { }) : baml.llm.PromptAst !! 173..308: unresolved name: GPT4 } -function user.GetMixedList$build_request() -> baml.http.Request throws never { +function user.GetMixedList$build_request() -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "GetMixedList", map { }) : baml.http.Request !! 173..308: unresolved name: GPT4 } function user.GetMixedList$parse(json: string) -> (int | string)[] throws never { baml.llm.parse("GetMixedList", json) : (int | string)[] } -function user.GetMixedMap() -> map throws never { +function user.GetMixedMap() -> map throws unknown { baml.llm.call_llm_function(GPT4, "GetMixedMap", map { }) : map !! 361..410: unresolved name: GPT4 } -function user.GetMixedMap$render_prompt() -> baml.llm.PromptAst throws never { +function user.GetMixedMap$render_prompt() -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "GetMixedMap", map { }) : baml.llm.PromptAst !! 308..410: unresolved name: GPT4 } -function user.GetMixedMap$build_request() -> baml.http.Request throws never { +function user.GetMixedMap$build_request() -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "GetMixedMap", map { }) : baml.http.Request !! 308..410: unresolved name: GPT4 } diff --git a/baml_language/crates/baml_tests/snapshots/comment_in_type/baml_tests__comment_in_type__04_tir.snap b/baml_language/crates/baml_tests/snapshots/comment_in_type/baml_tests__comment_in_type__04_tir.snap index a8db214c40..464489e43c 100644 --- a/baml_language/crates/baml_tests/snapshots/comment_in_type/baml_tests__comment_in_type__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/comment_in_type/baml_tests__comment_in_type__04_tir.snap @@ -5,25 +5,25 @@ source: crates/baml_tests/src/generated_tests.rs function user.TestClient$new() -> ? throws never { baml.llm.PrimitiveClient { name: "TestClient", provider: "openai", options: baml.llm.PrimitiveClientOptions { model: null, base_url: "https://api.opena...", allowed_role_metadata: null, finish_reason_allow_list: null, finish_reason_deny_list: null, supports_streaming: null, default_role: "system", allowed_roles: ["system", "user", "assistant"], remap_roles: null, api_key: null, provider_options: null, headers: map { }, query_params: map { }, request_body: map { } } } : unknown } -function user.FunctionWithComments(a: string, b: int) -> string throws never { +function user.FunctionWithComments(a: string, b: int) -> string throws unknown { baml.llm.call_llm_function(TestClient, "FunctionWithComments", map { "a": a, "b": b }) : string } -function user.FunctionWithComments$render_prompt(a: string, b: int) -> baml.llm.PromptAst throws never { +function user.FunctionWithComments$render_prompt(a: string, b: int) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(TestClient, "FunctionWithComments", map { "a": a, "b": b }) : baml.llm.PromptAst } -function user.FunctionWithComments$build_request(a: string, b: int) -> baml.http.Request throws never { +function user.FunctionWithComments$build_request(a: string, b: int) -> baml.http.Request throws unknown { baml.llm.build_request(TestClient, "FunctionWithComments", map { "a": a, "b": b }) : baml.http.Request } function user.FunctionWithComments$parse(json: string) -> string throws never { baml.llm.parse("FunctionWithComments", json) : string } -function user.CommentAfterArrow() -> string throws never { +function user.CommentAfterArrow() -> string throws unknown { baml.llm.call_llm_function(TestClient, "CommentAfterArrow", map { }) : string } -function user.CommentAfterArrow$render_prompt() -> baml.llm.PromptAst throws never { +function user.CommentAfterArrow$render_prompt() -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(TestClient, "CommentAfterArrow", map { }) : baml.llm.PromptAst } -function user.CommentAfterArrow$build_request() -> baml.http.Request throws never { +function user.CommentAfterArrow$build_request() -> baml.http.Request throws unknown { baml.llm.build_request(TestClient, "CommentAfterArrow", map { }) : baml.http.Request } function user.CommentAfterArrow$parse(json: string) -> string throws never { diff --git a/baml_language/crates/baml_tests/snapshots/control_flow/baml_tests__control_flow__04_tir.snap b/baml_language/crates/baml_tests/snapshots/control_flow/baml_tests__control_flow__04_tir.snap index 1f6f1c7945..f4d88d7d2f 100644 --- a/baml_language/crates/baml_tests/snapshots/control_flow/baml_tests__control_flow__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/control_flow/baml_tests__control_flow__04_tir.snap @@ -121,13 +121,13 @@ function user.NestedControlFlow(x: int) -> string throws never { function user.TestClient$new() -> ? throws never { baml.llm.PrimitiveClient { name: "TestClient", provider: "openai", options: baml.llm.PrimitiveClientOptions { model: "gpt-4", base_url: "https://api.opena...", allowed_role_metadata: null, finish_reason_allow_list: null, finish_reason_deny_list: null, supports_streaming: null, default_role: "system", allowed_roles: ["system", "user", "assistant"], remap_roles: null, api_key: null, provider_options: null, headers: map { }, query_params: map { }, request_body: map { } } } : unknown } -function user.LlmFunction(input: string) -> string throws never { +function user.LlmFunction(input: string) -> string throws unknown { baml.llm.call_llm_function(TestClient, "LlmFunction", map { "input": input }) : string } -function user.LlmFunction$render_prompt(input: string) -> baml.llm.PromptAst throws never { +function user.LlmFunction$render_prompt(input: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(TestClient, "LlmFunction", map { "input": input }) : baml.llm.PromptAst } -function user.LlmFunction$build_request(input: string) -> baml.http.Request throws never { +function user.LlmFunction$build_request(input: string) -> baml.http.Request throws unknown { baml.llm.build_request(TestClient, "LlmFunction", map { "input": input }) : baml.http.Request } function user.LlmFunction$parse(json: string) -> string throws never { diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap index 0eed641ea9..c865e0a63c 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__04_tir.snap @@ -558,13 +558,13 @@ function user.TriviaLongValueClient$new() -> ? throws never { function user.TriviaLongValueClientBlock$new() -> ? throws never { baml.llm.PrimitiveClient { name: "TriviaLongValueCl...", provider: "openai", options: baml.llm.PrimitiveClientOptions { model: "some-extremely-lo...", base_url: "https://some-real...", allowed_role_metadata: null, finish_reason_allow_list: null, finish_reason_deny_list: null, supports_streaming: null, default_role: "system", allowed_roles: ["system", "user", "assistant"], remap_roles: null, api_key: null, provider_options: null, headers: map { }, query_params: map { }, request_body: map { } } } : unknown } -function user.TestTarget() -> string throws never { +function user.TestTarget() -> string throws unknown { baml.llm.call_llm_function(BasicClient, "TestTarget", map { }) : string } -function user.TestTarget$render_prompt() -> baml.llm.PromptAst throws never { +function user.TestTarget$render_prompt() -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(BasicClient, "TestTarget", map { }) : baml.llm.PromptAst } -function user.TestTarget$build_request() -> baml.http.Request throws never { +function user.TestTarget$build_request() -> baml.http.Request throws unknown { baml.llm.build_request(BasicClient, "TestTarget", map { }) : baml.http.Request } function user.TestTarget$parse(json: string) -> string throws never { @@ -870,72 +870,72 @@ function user.DeepNestLongExpr(a: int, b: int, c: int, d: int) -> int throws nev } } } -function user.LlmBasic(name: string) -> string throws never { +function user.LlmBasic(name: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "LlmBasic", map { "name": name }) : string !! 11324..11381: unresolved name: GPT4 } -function user.LlmBasic$render_prompt(name: string) -> baml.llm.PromptAst throws never { +function user.LlmBasic$render_prompt(name: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "LlmBasic", map { "name": name }) : baml.llm.PromptAst !! 11247..11381: unresolved name: GPT4 } -function user.LlmBasic$build_request(name: string) -> baml.http.Request throws never { +function user.LlmBasic$build_request(name: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "LlmBasic", map { "name": name }) : baml.http.Request !! 11247..11381: unresolved name: GPT4 } function user.LlmBasic$parse(json: string) -> string throws never { baml.llm.parse("LlmBasic", json) : string } -function user.LlmMultiLine(text: string) -> string throws never { +function user.LlmMultiLine(text: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "LlmMultiLine", map { "text": text }) : string !! 11484..11606: unresolved name: GPT4 } -function user.LlmMultiLine$render_prompt(text: string) -> baml.llm.PromptAst throws never { +function user.LlmMultiLine$render_prompt(text: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "LlmMultiLine", map { "text": text }) : baml.llm.PromptAst !! 11381..11606: unresolved name: GPT4 } -function user.LlmMultiLine$build_request(text: string) -> baml.http.Request throws never { +function user.LlmMultiLine$build_request(text: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "LlmMultiLine", map { "text": text }) : baml.http.Request !! 11381..11606: unresolved name: GPT4 } function user.LlmMultiLine$parse(json: string) -> string throws never { baml.llm.parse("LlmMultiLine", json) : string } -function user.LlmStringClient(text: string) -> string throws never { +function user.LlmStringClient(text: string) -> string throws unknown { baml.llm.call_llm_function(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "LlmStringClient", map { "text": text }) : string } -function user.LlmStringClient$render_prompt(text: string) -> baml.llm.PromptAst throws never { +function user.LlmStringClient$render_prompt(text: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "LlmStringClient", map { "text": text }) : baml.llm.PromptAst } -function user.LlmStringClient$build_request(text: string) -> baml.http.Request throws never { +function user.LlmStringClient$build_request(text: string) -> baml.http.Request throws unknown { baml.llm.build_request(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "LlmStringClient", map { "text": text }) : baml.http.Request } function user.LlmStringClient$parse(json: string) -> string throws never { baml.llm.parse("LlmStringClient", json) : string } -function user.LlmReversedOrder(text: string) -> string throws never { +function user.LlmReversedOrder(text: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "LlmReversedOrder", map { "text": text }) : string !! 11884..11941: unresolved name: GPT4 } -function user.LlmReversedOrder$render_prompt(text: string) -> baml.llm.PromptAst throws never { +function user.LlmReversedOrder$render_prompt(text: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "LlmReversedOrder", map { "text": text }) : baml.llm.PromptAst !! 11771..11941: unresolved name: GPT4 } -function user.LlmReversedOrder$build_request(text: string) -> baml.http.Request throws never { +function user.LlmReversedOrder$build_request(text: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "LlmReversedOrder", map { "text": text }) : baml.http.Request !! 11771..11941: unresolved name: GPT4 } function user.LlmReversedOrder$parse(json: string) -> string throws never { baml.llm.parse("LlmReversedOrder", json) : string } -function user.LlmWithLongSignature(first_context: string, second_context: string, third_context: string, user_query: string) -> string throws never { +function user.LlmWithLongSignature(first_context: string, second_context: string, third_context: string, user_query: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "LlmWithLongSignature", map { "first_context": first_context, "second_context": second_context, "third_context": third_context, "user_query": user_query }) : string !! 12126..12277: unresolved name: GPT4 } -function user.LlmWithLongSignature$render_prompt(first_context: string, second_context: string, third_context: string, user_query: string) -> baml.llm.PromptAst throws never { +function user.LlmWithLongSignature$render_prompt(first_context: string, second_context: string, third_context: string, user_query: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "LlmWithLongSignature", map { "first_context": first_context, "second_context": second_context, "third_context": third_context, "user_query": user_query }) : baml.llm.PromptAst !! 11941..12277: unresolved name: GPT4 } -function user.LlmWithLongSignature$build_request(first_context: string, second_context: string, third_context: string, user_query: string) -> baml.http.Request throws never { +function user.LlmWithLongSignature$build_request(first_context: string, second_context: string, third_context: string, user_query: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "LlmWithLongSignature", map { "first_context": first_context, "second_context": second_context, "third_context": third_context, "user_query": user_query }) : baml.http.Request !! 11941..12277: unresolved name: GPT4 } diff --git a/baml_language/crates/baml_tests/snapshots/generator/baml_tests__generator__04_tir.snap b/baml_language/crates/baml_tests/snapshots/generator/baml_tests__generator__04_tir.snap index d96bed1bd3..05b7073c30 100644 --- a/baml_language/crates/baml_tests/snapshots/generator/baml_tests__generator__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/generator/baml_tests__generator__04_tir.snap @@ -6,15 +6,15 @@ class user.MyClass { name: string age: int } -function user.GetUser() -> user.MyClass throws never { +function user.GetUser() -> user.MyClass throws unknown { baml.llm.call_llm_function(null, "GetUser", map { }) : user.MyClass !! 347..397: type mismatch: expected baml.llm.Client, got null } -function user.GetUser$render_prompt() -> baml.llm.PromptAst throws never { +function user.GetUser$render_prompt() -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(null, "GetUser", map { }) : baml.llm.PromptAst !! 316..397: type mismatch: expected baml.llm.Client, got null } -function user.GetUser$build_request() -> baml.http.Request throws never { +function user.GetUser$build_request() -> baml.http.Request throws unknown { baml.llm.build_request(null, "GetUser", map { }) : baml.http.Request !! 316..397: type mismatch: expected baml.llm.Client, got null } diff --git a/baml_language/crates/baml_tests/snapshots/header_in_llm_function/baml_tests__header_in_llm_function__04_tir.snap b/baml_language/crates/baml_tests/snapshots/header_in_llm_function/baml_tests__header_in_llm_function__04_tir.snap index 9bcc01c802..2d093a3b18 100644 --- a/baml_language/crates/baml_tests/snapshots/header_in_llm_function/baml_tests__header_in_llm_function__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/header_in_llm_function/baml_tests__header_in_llm_function__04_tir.snap @@ -8,13 +8,13 @@ class user.Hi { function user.VertexGCP$new() -> ? throws never { baml.llm.PrimitiveClient { name: "VertexGCP", provider: "vertex-ai", options: baml.llm.PrimitiveClientOptions { model: "gemini-2", base_url: null, allowed_role_metadata: null, finish_reason_allow_list: null, finish_reason_deny_list: null, supports_streaming: null, default_role: "user", allowed_roles: ["system", "user", "assistant"], remap_roles: map { "assistant": "model" }, api_key: null, provider_options: null, headers: map { }, query_params: map { }, request_body: map { "flash": "" } } } : unknown } -function user.FetchDatax(user_id: string) -> user.Hi throws never { +function user.FetchDatax(user_id: string) -> user.Hi throws unknown { baml.llm.call_llm_function(VertexGCP, "FetchDatax", map { "user_id": user_id }) : user.Hi } -function user.FetchDatax$render_prompt(user_id: string) -> baml.llm.PromptAst throws never { +function user.FetchDatax$render_prompt(user_id: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(VertexGCP, "FetchDatax", map { "user_id": user_id }) : baml.llm.PromptAst } -function user.FetchDatax$build_request(user_id: string) -> baml.http.Request throws never { +function user.FetchDatax$build_request(user_id: string) -> baml.http.Request throws unknown { baml.llm.build_request(VertexGCP, "FetchDatax", map { "user_id": user_id }) : baml.http.Request } function user.FetchDatax$parse(json: string) -> user.Hi throws never { diff --git a/baml_language/crates/baml_tests/snapshots/parser_speculative/baml_tests__parser_speculative__04_tir.snap b/baml_language/crates/baml_tests/snapshots/parser_speculative/baml_tests__parser_speculative__04_tir.snap index 5631c1a2ca..316e0eaf2a 100644 --- a/baml_language/crates/baml_tests/snapshots/parser_speculative/baml_tests__parser_speculative__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/parser_speculative/baml_tests__parser_speculative__04_tir.snap @@ -2,30 +2,30 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === -function user.Ambiguous(input: string) -> string throws never { +function user.Ambiguous(input: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "Ambiguous", map { "input": input }) : string !! 89..220: unresolved name: GPT4 } -function user.Ambiguous$render_prompt(input: string) -> baml.llm.PromptAst throws never { +function user.Ambiguous$render_prompt(input: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "Ambiguous", map { "input": input }) : baml.llm.PromptAst !! 0..220: unresolved name: GPT4 } -function user.Ambiguous$build_request(input: string) -> baml.http.Request throws never { +function user.Ambiguous$build_request(input: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "Ambiguous", map { "input": input }) : baml.http.Request !! 0..220: unresolved name: GPT4 } function user.Ambiguous$parse(json: string) -> string throws never { baml.llm.parse("Ambiguous", json) : string } -function user.AnotherLLM(input: string) -> string throws never { +function user.AnotherLLM(input: string) -> string throws unknown { baml.llm.call_llm_function(null, "AnotherLLM", map { "input": input }) : string !! 309..421: type mismatch: expected baml.llm.Client, got null } -function user.AnotherLLM$render_prompt(input: string) -> baml.llm.PromptAst throws never { +function user.AnotherLLM$render_prompt(input: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(null, "AnotherLLM", map { "input": input }) : baml.llm.PromptAst !! 220..421: type mismatch: expected baml.llm.Client, got null } -function user.AnotherLLM$build_request(input: string) -> baml.http.Request throws never { +function user.AnotherLLM$build_request(input: string) -> baml.http.Request throws unknown { baml.llm.build_request(null, "AnotherLLM", map { "input": input }) : baml.http.Request !! 220..421: type mismatch: expected baml.llm.Client, got null } @@ -50,30 +50,30 @@ class user.JsonData$stream { field1: null | int field2: null | int } -function user.ExtractData(text: string) -> user.JsonData throws never { +function user.ExtractData(text: string) -> user.JsonData throws unknown { baml.llm.call_llm_function(GPT4, "ExtractData", map { "text": text }) : user.JsonData !! 46..146: unresolved name: GPT4 } -function user.ExtractData$render_prompt(text: string) -> baml.llm.PromptAst throws never { +function user.ExtractData$render_prompt(text: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "ExtractData", map { "text": text }) : baml.llm.PromptAst !! 0..146: unresolved name: GPT4 } -function user.ExtractData$build_request(text: string) -> baml.http.Request throws never { +function user.ExtractData$build_request(text: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "ExtractData", map { "text": text }) : baml.http.Request !! 0..146: unresolved name: GPT4 } function user.ExtractData$parse(json: string) -> user.JsonData throws never { baml.llm.parse("ExtractData", json) : user.JsonData } -function user.LLMAnalyze(text: string) -> user.Analysis throws never { +function user.LLMAnalyze(text: string) -> user.Analysis throws unknown { baml.llm.call_llm_function(GPT4, "LLMAnalyze", map { "text": text }) : user.Analysis !! 96..242: unresolved name: GPT4 } -function user.LLMAnalyze$render_prompt(text: string) -> baml.llm.PromptAst throws never { +function user.LLMAnalyze$render_prompt(text: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "LLMAnalyze", map { "text": text }) : baml.llm.PromptAst !! 0..242: unresolved name: GPT4 } -function user.LLMAnalyze$build_request(text: string) -> baml.http.Request throws never { +function user.LLMAnalyze$build_request(text: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "LLMAnalyze", map { "text": text }) : baml.http.Request !! 0..242: unresolved name: GPT4 } @@ -108,15 +108,15 @@ function user.ProcessAnalysis(analysis: user.Analysis) -> string throws never { } } } -function user.ChainedProcess(input: string) -> string throws never { +function user.ChainedProcess(input: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "ChainedProcess", map { "input": input }) : string !! 724..860: unresolved name: GPT4 } -function user.ChainedProcess$render_prompt(input: string) -> baml.llm.PromptAst throws never { +function user.ChainedProcess$render_prompt(input: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "ChainedProcess", map { "input": input }) : baml.llm.PromptAst !! 674..860: unresolved name: GPT4 } -function user.ChainedProcess$build_request(input: string) -> baml.http.Request throws never { +function user.ChainedProcess$build_request(input: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "ChainedProcess", map { "input": input }) : baml.http.Request !! 674..860: unresolved name: GPT4 } diff --git a/baml_language/crates/baml_tests/snapshots/parser_strings/baml_tests__parser_strings__04_tir.snap b/baml_language/crates/baml_tests/snapshots/parser_strings/baml_tests__parser_strings__04_tir.snap index cda9b2d5ea..3e765e7940 100644 --- a/baml_language/crates/baml_tests/snapshots/parser_strings/baml_tests__parser_strings__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/parser_strings/baml_tests__parser_strings__04_tir.snap @@ -14,15 +14,15 @@ class user.NestedQuotes { complex_nesting: string @description("It's \"really\" 'complex'") json_like: string @alias("{\"key\": \"value\"}") } -function user.FormatMessage(name: string) -> string throws never { +function user.FormatMessage(name: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "FormatMessage", map { "name": name }) : string !! 428..553: unresolved name: GPT4 } -function user.FormatMessage$render_prompt(name: string) -> baml.llm.PromptAst throws never { +function user.FormatMessage$render_prompt(name: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "FormatMessage", map { "name": name }) : baml.llm.PromptAst !! 345..553: unresolved name: GPT4 } -function user.FormatMessage$build_request(name: string) -> baml.http.Request throws never { +function user.FormatMessage$build_request(name: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "FormatMessage", map { "name": name }) : baml.http.Request !! 345..553: unresolved name: GPT4 } @@ -51,15 +51,15 @@ class user.RawStrings { and #hashes# "###) } -function user.ProcessText(input: string) -> string throws never { +function user.ProcessText(input: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "ProcessText", map { "input": input }) : string !! 411..532: unresolved name: GPT4 } -function user.ProcessText$render_prompt(input: string) -> baml.llm.PromptAst throws never { +function user.ProcessText$render_prompt(input: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "ProcessText", map { "input": input }) : baml.llm.PromptAst !! 364..532: unresolved name: GPT4 } -function user.ProcessText$build_request(input: string) -> baml.http.Request throws never { +function user.ProcessText$build_request(input: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "ProcessText", map { "input": input }) : baml.http.Request !! 364..532: unresolved name: GPT4 } diff --git a/baml_language/crates/baml_tests/snapshots/pending_greaters_fix/baml_tests__pending_greaters_fix__04_tir.snap b/baml_language/crates/baml_tests/snapshots/pending_greaters_fix/baml_tests__pending_greaters_fix__04_tir.snap index 394951ebd0..1766f63ad7 100644 --- a/baml_language/crates/baml_tests/snapshots/pending_greaters_fix/baml_tests__pending_greaters_fix__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/pending_greaters_fix/baml_tests__pending_greaters_fix__04_tir.snap @@ -11,30 +11,30 @@ class user.NestedGenerics { nested: map> simple: map } -function user.BadParam(x: map) -> string throws never { +function user.BadParam(x: map) -> string throws unknown { baml.llm.call_llm_function(GPT4, "BadParam", map { "x": x }) : string !! 741..777: unresolved name: GPT4 } -function user.BadParam$render_prompt(x: map) -> baml.llm.PromptAst throws never { +function user.BadParam$render_prompt(x: map) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "BadParam", map { "x": x }) : baml.llm.PromptAst !! 655..777: unresolved name: GPT4 } -function user.BadParam$build_request(x: map) -> baml.http.Request throws never { +function user.BadParam$build_request(x: map) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "BadParam", map { "x": x }) : baml.http.Request !! 655..777: unresolved name: GPT4 } function user.BadParam$parse(json: string) -> string throws never { baml.llm.parse("BadParam", json) : string } -function user.BadReturnType() -> map throws never { +function user.BadReturnType() -> map throws unknown { baml.llm.call_llm_function(GPT4, "BadReturnType", map { }) : map !! 861..897: unresolved name: GPT4 } -function user.BadReturnType$render_prompt() -> baml.llm.PromptAst throws never { +function user.BadReturnType$render_prompt() -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "BadReturnType", map { }) : baml.llm.PromptAst !! 777..897: unresolved name: GPT4 } -function user.BadReturnType$build_request() -> baml.http.Request throws never { +function user.BadReturnType$build_request() -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "BadReturnType", map { }) : baml.http.Request !! 777..897: unresolved name: GPT4 } diff --git a/baml_language/crates/baml_tests/snapshots/simple_function/baml_tests__simple_function__04_tir.snap b/baml_language/crates/baml_tests/snapshots/simple_function/baml_tests__simple_function__04_tir.snap index b9fd0c021d..3a688c52f0 100644 --- a/baml_language/crates/baml_tests/snapshots/simple_function/baml_tests__simple_function__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/simple_function/baml_tests__simple_function__04_tir.snap @@ -2,15 +2,15 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === -function user.HelloWorld(name: string) -> string throws never { +function user.HelloWorld(name: string) -> string throws unknown { baml.llm.call_llm_function(GPT4, "HelloWorld", map { "name": name }) : string !! 43..104: unresolved name: GPT4 } -function user.HelloWorld$render_prompt(name: string) -> baml.llm.PromptAst throws never { +function user.HelloWorld$render_prompt(name: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(GPT4, "HelloWorld", map { "name": name }) : baml.llm.PromptAst !! 0..104: unresolved name: GPT4 } -function user.HelloWorld$build_request(name: string) -> baml.http.Request throws never { +function user.HelloWorld$build_request(name: string) -> baml.http.Request throws unknown { baml.llm.build_request(GPT4, "HelloWorld", map { "name": name }) : baml.http.Request !! 0..104: unresolved name: GPT4 } diff --git a/baml_language/crates/baml_tests/snapshots/type_builder_errors/baml_tests__type_builder_errors__04_tir.snap b/baml_language/crates/baml_tests/snapshots/type_builder_errors/baml_tests__type_builder_errors__04_tir.snap index 8027cbefea..9a64aeedf2 100644 --- a/baml_language/crates/baml_tests/snapshots/type_builder_errors/baml_tests__type_builder_errors__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/type_builder_errors/baml_tests__type_builder_errors__04_tir.snap @@ -2,13 +2,13 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === -function user.TypeBuilderFn(from_text: string) -> user.Resume throws never { +function user.TypeBuilderFn(from_text: string) -> user.Resume throws unknown { baml.llm.call_llm_function(baml.llm.Client { name: "openai/gpt-4o-mini", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "TypeBuilderFn", map { "from_text": from_text }) : user.Resume } -function user.TypeBuilderFn$render_prompt(from_text: string) -> baml.llm.PromptAst throws never { +function user.TypeBuilderFn$render_prompt(from_text: string) -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(baml.llm.Client { name: "openai/gpt-4o-mini", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "TypeBuilderFn", map { "from_text": from_text }) : baml.llm.PromptAst } -function user.TypeBuilderFn$build_request(from_text: string) -> baml.http.Request throws never { +function user.TypeBuilderFn$build_request(from_text: string) -> baml.http.Request throws unknown { baml.llm.build_request(baml.llm.Client { name: "openai/gpt-4o-mini", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "TypeBuilderFn", map { "from_text": from_text }) : baml.http.Request } function user.TypeBuilderFn$parse(json: string) -> user.Resume throws never { diff --git a/baml_language/crates/baml_tests/snapshots/type_builder_test/baml_tests__type_builder_test__04_tir.snap b/baml_language/crates/baml_tests/snapshots/type_builder_test/baml_tests__type_builder_test__04_tir.snap index c388b23762..b61927152e 100644 --- a/baml_language/crates/baml_tests/snapshots/type_builder_test/baml_tests__type_builder_test__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/type_builder_test/baml_tests__type_builder_test__04_tir.snap @@ -2,13 +2,13 @@ source: crates/baml_tests/src/generated_tests.rs --- === TIR2 === -function user.Fn() -> string throws never { +function user.Fn() -> string throws unknown { baml.llm.call_llm_function(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "Fn", map { }) : string } -function user.Fn$render_prompt() -> baml.llm.PromptAst throws never { +function user.Fn$render_prompt() -> baml.llm.PromptAst throws unknown { baml.llm.render_prompt(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "Fn", map { }) : baml.llm.PromptAst } -function user.Fn$build_request() -> baml.http.Request throws never { +function user.Fn$build_request() -> baml.http.Request throws unknown { baml.llm.build_request(baml.llm.Client { name: "openai/gpt-4o", client_type: baml.llm.ClientType.Primitive, sub_clients: [], retry: null, counter: 0 }, "Fn", map { }) : baml.http.Request } function user.Fn$parse(json: string) -> string throws never { From 16f2a56e47b245f513f73f6bba84fb41e19747d6 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 17:25:23 -0500 Subject: [PATCH 22/26] Show declared and inferred throws in LSP function hover Add throws field to TypeInfo::Function populated from either inferred transitive throw sets (via function_throw_sets TIR query) or the declared throws clause from the function signature. Hover now shows e.g. `function use_handler(h: Handler) -> null throws string` for functions with inferred or declared throws. Functions that throw never omit the throws clause entirely. Add 4 new LSP hover tests covering: no-throws omission, declared throws, inferred throws from member calls, and union throws. --- .../crates/baml_lsp2_actions/src/type_info.rs | 50 ++++++++- .../baml_lsp2_actions/src/type_info_tests.rs | 105 ++++++++++++++++++ 2 files changed, 152 insertions(+), 3 deletions(-) diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info.rs b/baml_language/crates/baml_lsp2_actions/src/type_info.rs index 610bcbdcb5..bbdebee5f1 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info.rs @@ -52,11 +52,14 @@ use crate::{Db, utils}; /// different output contexts (markdown, plain text, etc.). #[derive(Debug, Clone, PartialEq, Eq)] pub enum TypeInfo { - /// A function definition: name, parameters (name + type string), return type. + /// A function definition: name, parameters (name + type string), return type, throws. Function { name: String, params: Vec<(String, String)>, return_type: Option, + /// `None` = throws never (omitted from display). + /// `Some("string")` = declared or inferred throws. + throws: Option, }, /// A class definition: name, fields (name + type string). Class { @@ -85,6 +88,7 @@ impl TypeInfo { name, params, return_type, + throws, } => { let param_strs: Vec = params.iter().map(|(n, t)| format!("{n}: {t}")).collect(); @@ -92,11 +96,16 @@ impl TypeInfo { .as_deref() .map(|r| format!(" -> {r}")) .unwrap_or_default(); + let throws_str = throws + .as_deref() + .map(|t| format!(" throws {t}")) + .unwrap_or_default(); format!( - "```baml\nfunction {}({}){}\n```", + "```baml\nfunction {}({}){}{}\n```", name, param_strs.join(", "), - ret + ret, + throws_str, ) } TypeInfo::Class { name, fields } => { @@ -216,10 +225,17 @@ fn type_info_for_definition(db: &dyn Db, def: Definition<'_>) -> TypeInfo { }) .collect(); let return_type = sig.return_type.as_ref().map(utils::display_type_expr); + + // Compute throws: prefer inferred/transitive throws from TIR, fall back + // to declared throws from the signature. + let throws = inferred_throws_for_function(db, func_loc, func_data) + .or_else(|| sig.throws.as_ref().map(utils::display_type_expr)); + TypeInfo::Function { name: func_name, params, return_type, + throws, } } @@ -342,6 +358,34 @@ fn type_info_for_definition(db: &dyn Db, def: Definition<'_>) -> TypeInfo { } } +// ── inferred throws ────────────────────────────────────────────────────────── + +/// Get the inferred (transitive) throws for a function, rendered as a display string. +/// +/// Returns `None` when the function throws `never` (empty throw set). +fn inferred_throws_for_function( + db: &dyn Db, + func_loc: baml_compiler2_hir::loc::FunctionLoc<'_>, + func_data: &baml_compiler2_hir::item_tree::Function, +) -> Option { + let pkg_info = baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)); + let pkg_id = baml_compiler2_hir::package::PackageId::new(db, pkg_info.package); + let throw_sets = baml_compiler2_tir::throw_inference::function_throw_sets(db, pkg_id); + + let key = baml_compiler2_tir::throw_inference::throw_set_key( + &pkg_info.namespace_path, + &func_data.name, + ); + + let facts = throw_sets.transitive_for(&key)?; + if facts.is_empty() { + return None; + } + + let parts: Vec = facts.iter().map(utils::display_ty).collect(); + Some(parts.join(" | ")) +} + // ── local_type_info ─────────────────────────────────────────────────────────── /// Build `TypeInfo::LocalVar` for a local variable (let binding or parameter). diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs index 238c550bd8..568c0eae9d 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs @@ -253,4 +253,109 @@ function <[CURSOR]many_handlers(hs: Handler[]) -> null { "Param type should render as Handler[], got: {md}" ); } + + // ── Throws display in hover ───────────────────────────────────────────── + + #[test] + fn hover_function_no_throws_omits_throws_clause() { + let test = CursorTest::new( + r#" +function <[CURSOR]safe_fn(x: int) -> string { + "hello" +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!( + throws, &None, + "Non-throwing function should have throws: None" + ); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + !md.contains("throws"), + "Non-throwing function hover should not contain 'throws', got: {md}" + ); + } + + #[test] + fn hover_function_with_declared_throws() { + let test = CursorTest::new( + r#" +function <[CURSOR]risky_fn(x: int) -> string throws string { + "hello" +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert!( + throws.is_some(), + "Function with declared throws should have throws: Some(...)" + ); + assert_eq!(throws.as_deref(), Some("string")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws string"), + "Hover should contain 'throws string', got: {md}" + ); + } + + #[test] + fn hover_function_with_inferred_throws_from_member_call() { + // use_handler calls h.run() which throws string — the inferred throws + // should propagate and appear in the hover. + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +function <[CURSOR]use_handler(h: Handler) -> null { + h.run() +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert!( + throws.is_some(), + "Function with inferred throws should have throws: Some(...)" + ); + assert_eq!(throws.as_deref(), Some("string")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws string"), + "Hover should show inferred 'throws string', got: {md}" + ); + } + + #[test] + fn hover_function_with_declared_union_throws() { + let test = CursorTest::new( + r#" +function <[CURSOR]multi_throw(x: int) -> string throws string | int { + "hello" +} +"#, + ); + let info = test.type_info().expect("should resolve"); + let md = info.to_hover_markdown(); + assert!( + md.contains("throws"), + "Hover should contain throws clause, got: {md}" + ); + } } From 0795fdb654889b7f63d8c53977089869c7884c1d Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 9 Apr 2026 18:52:34 -0500 Subject: [PATCH 23/26] Refactor throws propagation and callable resolution --- .../crates/baml_compiler2_tir/src/builder.rs | 472 +++--------- .../src/effective_throws.rs | 4 +- .../src/package_interface.rs | 575 ++++++++++++++- .../crates/baml_compiler2_tir/src/resolve.rs | 37 +- .../baml_compiler2_tir/src/throw_inference.rs | 669 ++++++++++++++---- .../baml_lsp2_actions/src/definition.rs | 118 ++- .../crates/baml_lsp2_actions/src/type_info.rs | 61 +- .../baml_lsp2_actions/src/type_info_tests.rs | 64 +- .../nested_function_throws_validation.baml | 8 +- .../generic_stored_callback.baml | 27 + .../heterogeneous_collection.baml | 39 + ...ws__01_lexer__generic_stored_callback.snap | 102 +++ ...s__01_lexer__heterogeneous_collection.snap | 172 +++++ ...s__02_parser__generic_stored_callback.snap | 163 +++++ ...__02_parser__heterogeneous_collection.snap | 259 +++++++ ...l_tests__function_type_throws__03_hir.snap | 36 + ...tests__function_type_throws__04_5_mir.snap | 386 +++++++++- ...l_tests__function_type_throws__04_tir.snap | 128 +++- ..._function_type_throws__05_diagnostics.snap | 260 +++++++ ...sts__function_type_throws__06_codegen.snap | 192 ++++- 20 files changed, 3161 insertions(+), 611 deletions(-) diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 5bc50e5443..4857289498 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -2871,12 +2871,16 @@ impl<'db> TypeInferenceBuilder<'db> { } else { first.clone() }; - let Some(pkg_items) = self.res_ctx.items_for_package(db, &pkg_name) else { - return Ty::Unknown { - attr: TyAttr::default(), - }; - }; + let own_pkg_name = self.package_id.name(db); + if pkg_name != own_pkg_name { + return self + .resolve_external_package_item(&pkg_name, &segments[1..], expr_id) + .unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }); + } + let pkg_items = self.package_items; if pkg_items.namespaces.is_empty() { return Ty::Unknown { attr: TyAttr::default(), @@ -2918,10 +2922,7 @@ impl<'db> TypeInferenceBuilder<'db> { let func_data_for_sig = &item_tree_for_func[func_loc.id(db)]; let generic_params = &func_data_for_sig.generic_params; let pkg_info = baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)); - let function_key = crate::throw_inference::throw_set_key( - &pkg_info.namespace_path, - &func_data_for_sig.name, - ); + let function_key = crate::throw_inference::callable_throw_key(db, func_loc); let ns_context = pkg_info.namespace_path; self.resolutions.insert( expr_id, @@ -2968,6 +2969,51 @@ impl<'db> TypeInferenceBuilder<'db> { None } + fn resolve_external_package_item( + &mut self, + pkg_name: &Name, + path: &[Name], + expr_id: ExprId, + ) -> Option { + if path.is_empty() { + return None; + } + + let db = self.context.db(); + let mut full_path = Vec::with_capacity(path.len() + 1); + full_path.push(pkg_name.clone()); + full_path.extend_from_slice(path); + + if let Some((_source, function)) = self.res_ctx.lookup_function(db, &full_path, &[]) { + let item = path.last().expect("non-empty path"); + if let Some(Definition::Function(func_loc)) = self + .res_ctx + .lookup_value_definition_in_package(db, pkg_name, &path[..path.len() - 1], item) + { + self.resolutions.insert( + expr_id, + crate::inference::MemberResolution::Free { func_loc }, + ); + } + return Some(Ty::Function { + params: function + .params + .into_iter() + .map(|(n, ty)| (Some(n), ty)) + .collect(), + ret: Box::new(function.return_type), + throws: Box::new(function.outward_throws.unwrap_or(Ty::Never { + attr: TyAttr::default(), + })), + attr: TyAttr::default(), + }); + } + + self.res_ctx + .resolve_type(db, &full_path, &[]) + .map(|(_, ty)| ty) + } + /// Resolve a single name to its type. /// /// Checks local variables first, then value namespace (functions), then @@ -2992,8 +3038,7 @@ impl<'db> TypeInferenceBuilder<'db> { let sig_pkg = baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)); let sig_ns = sig_pkg.namespace_path; - let function_key = - crate::throw_inference::throw_set_key(&sig_ns, &func_data.name); + let function_key = crate::throw_inference::callable_throw_key(db, func_loc); let func_scope = self.find_function_scope_id( func_loc.file(db), func_data.span, @@ -3379,46 +3424,22 @@ impl<'db> TypeInferenceBuilder<'db> { .collect() } - /// Fetch `PackageItems` for the package that owns a class type. - /// - /// Returns `self.package_items` when the class is in the current package, - /// or loads the correct foreign package's items when the class lives in a - /// declared dependency package. - fn resolve_class_pkg_items( - &self, - class_pkg: &baml_base::Name, - ) -> Option<&'db baml_compiler2_hir::package::PackageItems<'db>> { - let db = self.context.db(); - self.res_ctx.items_for_package(db, class_pkg) - } - - /// Resolve a `QualifiedTypeName` to a `ClassLoc` via `package_items` lookup. + /// Resolve a `QualifiedTypeName` to a `ClassLoc` via `PackageResolutionContext`. fn resolve_class_loc( &self, qtn: &crate::ty::QualifiedTypeName, ) -> Option> { - let pkg_items = self.resolve_class_pkg_items(qtn.package())?; - match pkg_items.lookup_type(qtn.namespace(), qtn.name())? { - Definition::Class(class_loc) => Some(class_loc), - _ => None, - } + let db = self.context.db(); + self.res_ctx.lookup_class_loc(db, qtn) } - /// Resolve a `QualifiedTypeName` to an `EnumLoc` via `package_items` lookup. + /// Resolve a `QualifiedTypeName` to an `EnumLoc` via `PackageResolutionContext`. fn resolve_enum_loc( &self, qtn: &crate::ty::QualifiedTypeName, ) -> Option> { let db = self.context.db(); - let items = if *qtn.package() == self.package_id.name(db) { - self.package_items - } else { - self.res_ctx.items_for_package(db, qtn.package())? - }; - match items.lookup_type(qtn.namespace(), qtn.name())? { - Definition::Enum(enum_loc) => Some(enum_loc), - _ => None, - } + self.res_ctx.lookup_enum_loc(db, qtn) } /// Look up a class method by name. @@ -3442,29 +3463,16 @@ impl<'db> TypeInferenceBuilder<'db> { baml_compiler2_hir::loc::FunctionLoc<'db>, )> { let db = self.context.db(); - let pkg_items_for_class = self.resolve_class_pkg_items(class_name.package())?; - let def = pkg_items_for_class.lookup_type(class_name.namespace(), class_name.name())?; - let Definition::Class(class_loc) = def else { - return None; - }; - let file = class_loc.file(db); - let item_tree = baml_compiler2_ppir::file_item_tree(db, file); - let class_data = &item_tree[class_loc.id(db)]; - - // Resolve the FunctionLoc for the method (needed for MemberResolution regardless of path). - let func_loc = class_data.methods.iter().find_map(|&method_id| { - let method_data = &item_tree[method_id]; - if method_data.name == *method_name { - Some(baml_compiler2_hir::loc::FunctionLoc::new( - db, file, method_id, - )) - } else { - None - } - })?; - let own_pkg_name = self.package_id.name(db); if *class_name.package() == own_pkg_name { + let class_loc = self.res_ctx.lookup_class_loc(db, class_name.qtn())?; + let file = class_loc.file(db); + let item_tree = baml_compiler2_ppir::file_item_tree(db, file); + let class_data = &item_tree[class_loc.id(db)]; + let func_loc = self + .res_ctx + .lookup_class_method_locs(db, class_name, method_name)? + .1; // Own-package path: build function type from signature with scope/throw-set support. let ns_context = baml_compiler2_hir::file_package::file_package(db, file).namespace_path; @@ -3478,15 +3486,12 @@ impl<'db> TypeInferenceBuilder<'db> { all_generic_params.extend(method_data.generic_params.iter().cloned()); let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); let class_ty = Ty::Class(class_name.clone(), TyAttr::default()); - let function_key = crate::throw_inference::throw_set_key( - &ns_context, - &Name::new(format!("{}.{}", class_name.name(), method_data.name)), - ); + let function_key = crate::throw_inference::callable_throw_key(db, func_loc); let method_scope = self.find_function_scope_id(file, method_data.span, &method_data.name); let method_body = baml_compiler2_hir::body::function_body(db, func_loc); let ty = self.build_function_ty_from_signature( - pkg_items_for_class, + self.package_items, &ns_context, &all_generic_params, sig.as_ref(), @@ -3512,6 +3517,9 @@ impl<'db> TypeInferenceBuilder<'db> { // Dep-package path: delegate type-level resolution to PackageResolutionContext, // which reads from the pre-resolved PackageInterface and applies class-level // generic substitution. + let (class_loc, func_loc) = + self.res_ctx + .lookup_class_method_locs(db, class_name, method_name)?; let resolved = self .res_ctx .lookup_class_method(db, class_name, method_name)?; @@ -3523,7 +3531,7 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|(n, ty)| (Some(n), ty)) .collect(), ret: Box::new(resolved.function.return_type), - throws: Box::new(resolved.function.throws.unwrap_or(Ty::Never { + throws: Box::new(resolved.function.outward_throws.unwrap_or(Ty::Never { attr: TyAttr::default(), })), attr: TyAttr::default(), @@ -3605,27 +3613,17 @@ impl<'db> TypeInferenceBuilder<'db> { let baml_ns_shorthands: &[&str] = &[ "env", "sys", "http", "math", "fs", "net", "media", "llm", "errors", "unstable", ]; - let (pkg_items, item_path_owned): ( - &baml_compiler2_hir::package::PackageItems<'db>, - Vec, - ) = if let Some(items) = self.res_ctx.items_for_package(db, &resolved_pkg_name) { - // Found the package directly. - if items.namespaces.is_empty() { - return None; - } - let ip = segments[1..].to_vec(); - (items, ip) + let is_current_package = first.as_str() == "root" + || first.as_str() == self.package_id.name(db).as_str() + || resolved_pkg_name == self.package_id.name(db); + let (target_pkg_name, item_path_owned): (Name, Vec) = if is_current_package { + (resolved_pkg_name, segments[1..].to_vec()) } else if baml_ns_shorthands.contains(&first.as_str()) { - // `env.X` → treat as `baml.env.X`: look up in the `"baml"` package - // with the namespace prefix prepended to the item path. - let baml_name = Name::new("baml"); - let baml_items = self.res_ctx.items_for_package(db, &baml_name)?; - // Prepend the namespace segment (`first`) to the item path. let mut ip = vec![first.clone()]; ip.extend_from_slice(&segments[1..]); - (baml_items, ip) + (Name::new("baml"), ip) } else { - return None; + (resolved_pkg_name, segments[1..].to_vec()) }; let item_path: &[Name] = &item_path_owned; @@ -3659,34 +3657,12 @@ impl<'db> TypeInferenceBuilder<'db> { // When there is a non-empty item_path, try class/enum member resolution // first (e.g. `baml.Array.length` → Array class, then method "length"). if !item_path.is_empty() { - let item_name = item_path.last().expect("non-empty item_path"); - if let Some(def) = pkg_items.lookup_type(&item_path[..item_path.len() - 1], item_name) { - match def { - Definition::Class(_class_loc) => { - if first.as_str() == "root" - || first.as_str() == self.package_id.name(db).as_str() - { - let class_qtn = crate::lower_type_expr::qualify_def(db, def, item_name); - let base_ty = Ty::Class(class_qtn.into(), TyAttr::default()); - return Some(self.resolve_member(&base_ty, member, at)); - } - let class_path: Vec<&str> = - item_path.iter().map(smol_str::SmolStr::as_str).collect(); - return self - .resolve_builtin_member(&class_path, &[], member, at) - .or_else(|| { - let class_qtn = - crate::lower_type_expr::qualify_def(db, def, item_name); - let base_ty = Ty::Class(class_qtn.into(), TyAttr::default()); - Some(self.resolve_member(&base_ty, member, at)) - }); - } - Definition::Enum(_) => { - let enum_qtn = crate::lower_type_expr::qualify_def(db, def, item_name); - let base_ty = Ty::Enum(enum_qtn.into(), TyAttr::default()); - return Some(self.resolve_member(&base_ty, member, at)); - } - _ => {} + let mut full_base_path = Vec::with_capacity(item_path.len() + 1); + full_base_path.push(target_pkg_name.clone()); + full_base_path.extend_from_slice(item_path); + if let Some((_source, base_ty)) = self.res_ctx.resolve_type(db, &full_base_path, &[]) { + if matches!(base_ty, Ty::Class(..) | Ty::Enum(..)) { + return Some(self.resolve_member(&base_ty, member, at)); } } } @@ -3698,7 +3674,11 @@ impl<'db> TypeInferenceBuilder<'db> { .chain(std::iter::once(member)) .cloned() .collect(); - let result = self.resolve_package_item(pkg_items, &full_path, at); + let result = if target_pkg_name == self.package_id.name(db) { + self.resolve_package_item(self.package_items, &full_path, at) + } else { + self.resolve_external_package_item(&target_pkg_name, &full_path, at) + }; if result.is_none() { // Package was found but the member doesn't exist — report a clear error // with the full dotted path as context (e.g. "unresolved member: testing.Quorum"). @@ -3747,16 +3727,9 @@ impl<'db> TypeInferenceBuilder<'db> { /// Resolve a method or field on a builtin class declared in the `"baml"` package. /// - /// 1. Fetches `package_items(db, "baml")`. - /// 2. Looks up `class_name` in the root namespace. - /// 3. Binds the class's `generic_params` to `type_args` (e.g. `{T → int}`). - /// 4. Searches the class methods for `member_name`, lowering the method's - /// parameter and return types with type variable substitution applied. - /// 5. Falls back to checking class fields. - /// - /// Returns `None` if the class or member is not found. - /// Wrapper around `resolve_builtin_method` that also stores a `MemberResolution` - /// when the result is a method (not a field). + /// This delegates the actual foreign-package lookup to + /// `PackageResolutionContext`, then records a `MemberResolution` when the + /// resolved member is a method. fn resolve_builtin_member( &mut self, class_path: &[&str], @@ -3791,226 +3764,24 @@ impl<'db> TypeInferenceBuilder<'db> { member_name: &Name, ) -> Option> { let db = self.context.db(); - let baml_items = self - .res_ctx - .items_for_package(db, &baml_base::Name::new("baml"))?; - - // Look up the class by path (e.g. &["Array"] or &["media", "Image"]). let path: Vec = class_path.iter().map(baml_base::Name::new).collect(); - let item = path.last().expect("non-empty class_path"); - let def = baml_items.lookup_type(&path[..path.len() - 1], item)?; - let baml_compiler2_hir::contributions::Definition::Class(class_loc) = def else { - return None; - }; - - let file = class_loc.file(db); - let stub_pkg = baml_compiler2_hir::file_package::file_package(db, file); - let stub_ns: &[Name] = &stub_pkg.namespace_path; - let item_tree = baml_compiler2_ppir::file_item_tree(db, file); - let class_data = &item_tree[class_loc.id(db)]; - - // Bind generic type variables: e.g. {T → int} for Array. - let mut bindings = crate::generics::bind_type_vars(&class_data.generic_params, type_args); - - // Search methods first. - for &method_id in &class_data.methods { - let method_data = &item_tree[method_id]; - if method_data.name == *member_name { - // Add method-level generics as TypeVar entries so they survive - // lowering and can be resolved by call-site inference. - for gp in &method_data.generic_params { - bindings - .entry(gp.clone()) - .or_insert_with(|| Ty::TypeVar(gp.clone(), TyAttr::default())); - } - let func_loc = baml_compiler2_hir::loc::FunctionLoc::new(db, file, method_id); - let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); - let mut diags = Vec::new(); - // Build the class type for self parameter resolution. - // For generics, apply type_args (e.g. Array). - let builtin_class_ty = if type_args.is_empty() { - let pkg_info = baml_compiler2_hir::file_package::file_package(db, file); - Ty::Class( - crate::ty::QualifiedTypeName::new_with_generic_params( - pkg_info.package, - pkg_info.namespace_path, - class_data.name.clone(), - class_data.generic_params.clone(), - ) - .into(), - TyAttr::default(), - ) - } else if type_args.len() == 1 { - // Single type arg: Array → List(T), special-case common containers - match class_data.name.as_str() { - "Array" => Ty::List(Box::new(type_args[0].clone()), TyAttr::default()), - _ => { - let pkg_info = baml_compiler2_hir::file_package::file_package(db, file); - Ty::Class( - crate::ty::QualifiedTypeName::new( - pkg_info.package, - pkg_info.namespace_path, - class_data.name.clone(), - ) - .into(), - TyAttr::default(), - ) - } - } - } else if type_args.len() == 2 { - match class_data.name.as_str() { - "Map" => Ty::Map( - Box::new(type_args[0].clone()), - Box::new(type_args[1].clone()), - TyAttr::default(), - ), - _ => { - let pkg_info = baml_compiler2_hir::file_package::file_package(db, file); - Ty::Class( - crate::ty::QualifiedTypeName::new( - pkg_info.package, - pkg_info.namespace_path, - class_data.name.clone(), - ) - .into(), - TyAttr::default(), - ) - } - } - } else { - let pkg_info = baml_compiler2_hir::file_package::file_package(db, file); - Ty::Class( - crate::ty::QualifiedTypeName::new( - pkg_info.package, - pkg_info.namespace_path, - class_data.name.clone(), - ) - .into(), - TyAttr::default(), - ) - }; - - // Collect generic params for this method (class + method generics). - let method_generic_params: Vec = class_data - .generic_params - .iter() - .chain(method_data.generic_params.iter()) - .cloned() - .collect(); - - let mut synthetic_effect_vars: Vec = Vec::new(); - let params: Vec<(Option, Ty)> = sig - .params - .iter() - .map(|(n, te)| { - let ty = if n.as_str() == "self" - && matches!(te, baml_compiler2_ast::TypeExpr::Unknown { .. }) - { - builtin_class_ty.clone() - } else if matches!(te, baml_compiler2_ast::TypeExpr::Function { .. }) { - // Function-typed parameter: use DirectParamRoot to create effect var. - let lowered = crate::lower_type_expr::lower_type_expr_with_fn_context( - db, - te, - baml_items, - stub_ns, - &method_generic_params, - &mut diags, - &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { - param_name: n.clone(), - }, - &mut synthetic_effect_vars, - ); - // Apply generic bindings to substitute type args. - crate::generics::substitute_ty(&lowered, &bindings) - } else { - crate::generics::lower_type_expr_with_generics( - db, - te, - self.package_items, - stub_ns, - &bindings, - &mut diags, - ) - }; - (Some(n.clone()), ty) - }) - .collect(); - let ret = sig - .return_type - .as_ref() - .map(|te| { - crate::generics::lower_type_expr_with_generics( - db, - te, - self.package_items, - stub_ns, - &bindings, - &mut diags, - ) - }) - .unwrap_or(Ty::Void { - attr: TyAttr::default(), - }); - - // Compute throws: if we have synthetic effect vars, use them. - let throws_ty = match synthetic_effect_vars.len() { - 0 => Ty::Never { - attr: TyAttr::default(), - }, - 1 => Ty::TypeVar(synthetic_effect_vars[0].clone(), TyAttr::default()), - _ => Ty::Union( - synthetic_effect_vars - .iter() - .map(|v| Ty::TypeVar(v.clone(), TyAttr::default())) - .collect(), - TyAttr::default(), - ), - }; - - // Discard diags — they will be reported at the definition site - // (the builtin .baml stub). We don't want to spam user code - // with unresolved-type errors from builtin signatures. - drop(diags); - return Some(BuiltinResolution::Method { - ty: Ty::Function { - params, - ret: Box::new(ret), - throws: Box::new(throws_ty), - attr: TyAttr::default(), - }, - class_loc, - func_loc, - }); - } - } - - // Fall back to fields (e.g. Request.method, Request.url). - for field in &class_data.fields { - if field.name == *member_name { - let mut diags = Vec::new(); - let field_ty = field - .type_expr - .as_ref() - .map(|te| { - crate::generics::lower_type_expr_with_generics( - db, - &te.expr, - self.package_items, - stub_ns, - &bindings, - &mut diags, - ) - }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }); - drop(diags); - return Some(BuiltinResolution::Field(field_ty)); + match self + .res_ctx + .resolve_builtin_member(db, &path, type_args, member_name)? + { + crate::package_interface::ResolvedBuiltinMember::Method { + ty, + class_loc, + func_loc, + } => Some(BuiltinResolution::Method { + ty, + class_loc, + func_loc, + }), + crate::package_interface::ResolvedBuiltinMember::Field(ty) => { + Some(BuiltinResolution::Field(ty)) } } - - None } /// Look up enum variants from the package items (via item tree). @@ -4019,26 +3790,7 @@ impl<'db> TypeInferenceBuilder<'db> { /// not just the current file's package. fn lookup_enum_variants(&self, enum_name: &crate::ty::QualifiedTypeName) -> Vec { let db = self.context.db(); - - // Resolve the package that owns the enum via res_ctx. - let items = if *enum_name.package() == self.package_id.name(db) { - self.package_items - } else { - match self.res_ctx.items_for_package(db, enum_name.package()) { - Some(items) => items, - None => return Vec::new(), - } - }; - - if let Some(Definition::Enum(enum_loc)) = - items.lookup_type(enum_name.namespace(), enum_name.name()) - { - let file = enum_loc.file(db); - let item_tree = baml_compiler2_ppir::file_item_tree(db, file); - let enum_data = &item_tree[enum_loc.id(db)]; - return enum_data.variants.iter().map(|v| v.name.clone()).collect(); - } - Vec::new() + self.res_ctx.lookup_enum_variants(db, enum_name) } // ── Evolving Container Mutations ───────────────────────────────────────── diff --git a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs index 18602215f4..2165c13952 100644 --- a/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs +++ b/baml_language/crates/baml_compiler2_tir/src/effective_throws.rs @@ -308,7 +308,7 @@ fn call_target_name( ) -> Option { match &body.exprs[callee_expr_id] { Expr::Path(segments) if !segments.is_empty() => Some(path_name(segments)), - Expr::FieldAccess { base, field } => { + Expr::FieldAccess { base, field } | Expr::OptionalFieldAccess { base, field } => { if let Some(Ty::Class(qn, _)) = expressions.get(base) { Some(class_method_key(qn, field)) } else { @@ -324,7 +324,7 @@ fn call_target_name( fn expr_to_path_segments(expr_id: ExprId, body: &ExprBody) -> Option> { match &body.exprs[expr_id] { Expr::Path(segments) if !segments.is_empty() => Some(segments.clone()), - Expr::FieldAccess { base, field } => { + Expr::FieldAccess { base, field } | Expr::OptionalFieldAccess { base, field } => { let mut segments = expr_to_path_segments(*base, body)?; segments.push(field.clone()); Some(segments) diff --git a/baml_language/crates/baml_compiler2_tir/src/package_interface.rs b/baml_language/crates/baml_compiler2_tir/src/package_interface.rs index 06f7d04914..defed4c390 100644 --- a/baml_language/crates/baml_compiler2_tir/src/package_interface.rs +++ b/baml_language/crates/baml_compiler2_tir/src/package_interface.rs @@ -8,6 +8,8 @@ //! `PackageResolutionContext` bundles a package's own `PackageItems` with its //! dependencies' `PackageInterface`s, providing unified lookup methods. +use std::collections::BTreeSet; + use baml_base::Name; use baml_compiler2_ast::BuiltinKind; use baml_compiler2_hir::{ @@ -33,8 +35,8 @@ pub struct PackageInterface { pub types: FxHashMap, FxHashMap>, /// All exported free functions: namespace path -> name -> `ExportedFunction` pub functions: FxHashMap, FxHashMap>, - /// Throw sets for all functions in this package (transitive, fully inferred). - pub throw_sets: FunctionThrowSets, + /// Canonical callable keys exported by this package. + pub callable_keys: BTreeSet, } /// A type exported from a package. @@ -60,9 +62,10 @@ pub enum ExportedType { #[derive(Debug, Clone, PartialEq)] pub struct ExportedFunction { pub name: Name, + pub callable_key: Name, pub params: Vec<(Name, Ty)>, pub return_type: Ty, - pub throws: Option, + pub declared_throws: Option, pub generic_params: Vec, pub builtin_kind: Option, } @@ -81,7 +84,7 @@ pub struct ResolvedFunction { pub name: Name, pub params: Vec<(Name, Ty)>, pub return_type: Ty, - pub throws: Option, + pub outward_throws: Option, pub generic_params: Vec, pub builtin_kind: Option, } @@ -93,11 +96,26 @@ pub struct ResolvedMethod { pub class_generic_params: Vec, } +/// Specialized resolution result for builtin class members. +/// +/// Builtin methods can be effect-polymorphic over direct callback parameters, so +/// they still need call-site-aware lowering even though normal dependency +/// resolution flows through exported interfaces. +pub enum ResolvedBuiltinMember<'db> { + Method { + ty: Ty, + class_loc: baml_compiler2_hir::loc::ClassLoc<'db>, + func_loc: baml_compiler2_hir::loc::FunctionLoc<'db>, + }, + Field(Ty), +} + /// Bundles a package's own items with its dependencies' pre-resolved interfaces. /// All cross-package lookups go through this context's methods. pub struct PackageResolutionContext<'db> { pub own_items: &'db PackageItems<'db>, pub dep_interfaces: Vec<(Name, &'db PackageInterface)>, + pub dep_packages: Vec<(Name, PackageId<'db>)>, pub own_package_name: Name, } @@ -106,6 +124,7 @@ impl PartialEq for PackageResolutionContext<'_> { std::ptr::eq(self.own_items, other.own_items) && self.own_package_name == other.own_package_name && self.dep_interfaces.len() == other.dep_interfaces.len() + && self.dep_packages == other.dep_packages && self .dep_interfaces .iter() @@ -192,6 +211,7 @@ pub fn package_interface<'db>(db: &'db dyn crate::Db, pkg_id: PackageId<'db>) -> let mut types: FxHashMap, FxHashMap> = FxHashMap::default(); let mut functions: FxHashMap, FxHashMap> = FxHashMap::default(); + let mut callable_keys = BTreeSet::new(); for (ns_path, ns_items) in &pkg_items.namespaces { // Export types @@ -277,7 +297,7 @@ pub fn package_interface<'db>(db: &'db dyn crate::Db, pkg_id: PackageId<'db>) -> }, ); - let throws = sig.throws.as_ref().map(|te| { + let explicit_throws = sig.throws.as_ref().map(|te| { lower_type_expr_in_ns( db, te, @@ -295,12 +315,17 @@ pub fn package_interface<'db>(db: &'db dyn crate::Db, pkg_id: PackageId<'db>) -> methods.push(ExportedFunction { name: method_data.name.clone(), + callable_key: crate::throw_inference::callable_throw_key( + db, method_loc, + ), params, return_type, - throws, + declared_throws: explicit_throws, generic_params: method_data.generic_params.clone(), builtin_kind, }); + callable_keys + .insert(crate::throw_inference::callable_throw_key(db, method_loc)); } let qtn = qualify_def(db, *def, name); @@ -386,7 +411,7 @@ pub fn package_interface<'db>(db: &'db dyn crate::Db, pkg_id: PackageId<'db>) -> }, ); - let throws = sig.throws.as_ref().map(|te| { + let explicit_throws = sig.throws.as_ref().map(|te| { lower_type_expr_in_ns( db, te, @@ -406,23 +431,22 @@ pub fn package_interface<'db>(db: &'db dyn crate::Db, pkg_id: PackageId<'db>) -> name.clone(), ExportedFunction { name: name.clone(), + callable_key: crate::throw_inference::callable_throw_key(db, *func_loc), params, return_type, - throws, + declared_throws: explicit_throws, generic_params: func_data.generic_params.clone(), builtin_kind, }, ); + callable_keys.insert(crate::throw_inference::callable_throw_key(db, *func_loc)); } } - // Compute throw sets for this package - let throw_sets = function_throw_sets(db, pkg_id); - PackageInterface { types, functions, - throw_sets: throw_sets.clone(), + callable_keys, } } @@ -463,6 +487,17 @@ fn build_self_type_for_class( } } +fn outward_throws_from_key(throw_sets: &FunctionThrowSets, key: &Name) -> Option { + let facts = throw_sets.transitive_for(key)?; + if facts.is_empty() { + None + } else { + Some(crate::throws_semantics::concrete_throws_ty_from_facts( + facts.clone(), + )) + } +} + // ── package_resolution_context Salsa query ───────────────────────────────── #[salsa::tracked(returns(ref))] @@ -472,6 +507,10 @@ pub fn package_resolution_context<'db>( ) -> PackageResolutionContext<'db> { let own_items = package_items(db, pkg_id); let deps = package_dependencies(db, pkg_id); + let dep_packages: Vec<(Name, PackageId<'db>)> = deps + .iter() + .map(|dep_id| (dep_id.name(db), *dep_id)) + .collect(); let dep_interfaces: Vec<(Name, &PackageInterface)> = deps .iter() .map(|dep_id| { @@ -483,6 +522,7 @@ pub fn package_resolution_context<'db>( PackageResolutionContext { own_items, dep_interfaces, + dep_packages, own_package_name: pkg_id.name(db), } } @@ -490,11 +530,17 @@ pub fn package_resolution_context<'db>( // ── PackageResolutionContext lookup methods ───────────────────────────────── impl<'db> PackageResolutionContext<'db> { + fn dep_package_id(&self, pkg_name: &Name) -> Option> { + self.dep_packages + .iter() + .find_map(|(dep_name, dep_id)| (dep_name == pkg_name).then_some(*dep_id)) + } + /// Get `PackageItems` for an accessible package (own or declared dependency). /// /// Returns `Some` for the own package and any declared dependency, /// `None` for undeclared packages. - pub fn items_for_package( + fn items_for_package( &self, db: &'db dyn crate::Db, pkg_name: &Name, @@ -513,6 +559,159 @@ impl<'db> PackageResolutionContext<'db> { } } + pub fn lookup_function( + &self, + db: &'db dyn crate::Db, + path: &[Name], + ns_context: &[Name], + ) -> Option<(ResolvedSource, ResolvedFunction)> { + let item = path.last()?; + if !ns_context.is_empty() { + let ns: Vec<_> = ns_context + .iter() + .chain(path[..path.len() - 1].iter()) + .cloned() + .collect(); + if let Some(result) = self.lookup_function_in_own_then_deps(db, &ns, item) { + return Some(result); + } + } + if ns_context.is_empty() { + if let Some(result) = + self.lookup_function_in_own_then_deps(db, &path[..path.len() - 1], item) + { + return Some(result); + } + } + if path.len() >= 2 { + if path[0].as_str() == "root" { + if let Some(function) = self.lookup_own_function(db, &path[1..path.len() - 1], item) + { + return Some((ResolvedSource::Item, function)); + } + } + for (dep_name, _dep_iface) in &self.dep_interfaces { + if &path[0] == dep_name { + if let Some(function) = + self.lookup_dep_function(db, dep_name, &path[1..path.len() - 1], item) + { + return Some((ResolvedSource::Builtin, function)); + } + } + } + } + None + } + + fn lookup_function_in_own_then_deps( + &self, + db: &'db dyn crate::Db, + namespace: &[Name], + item: &Name, + ) -> Option<(ResolvedSource, ResolvedFunction)> { + if let Some(function) = self.lookup_own_function(db, namespace, item) { + return Some((ResolvedSource::Item, function)); + } + for (dep_name, _dep_iface) in &self.dep_interfaces { + if let Some(function) = self.lookup_dep_function(db, dep_name, namespace, item) { + return Some((ResolvedSource::Builtin, function)); + } + } + None + } + + fn lookup_own_function( + &self, + db: &'db dyn crate::Db, + namespace: &[Name], + item: &Name, + ) -> Option { + let Definition::Function(func_loc) = self.own_items.lookup_value(namespace, item)? else { + return None; + }; + let item_tree = file_item_tree(db, func_loc.file(db)); + let func_data = &item_tree[func_loc.id(db)]; + let func_ns = file_package::file_package(db, func_loc.file(db)).namespace_path; + let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); + let body = baml_compiler2_hir::body::function_body(db, func_loc); + let mut diags = Vec::new(); + let mut params = Vec::new(); + for (param_name, param_te) in &sig.params { + let param_ty = lower_type_expr_in_ns( + db, + param_te, + self.own_items, + &func_ns, + &func_data.generic_params, + &mut diags, + ); + params.push((param_name.clone(), param_ty)); + } + let return_type = sig.return_type.as_ref().map_or( + Ty::Unknown { + attr: TyAttr::default(), + }, + |te| { + lower_type_expr_in_ns( + db, + te, + self.own_items, + &func_ns, + &func_data.generic_params, + &mut diags, + ) + }, + ); + let explicit_throws = sig.throws.as_ref().map(|te| { + lower_type_expr_in_ns( + db, + te, + self.own_items, + &func_ns, + &func_data.generic_params, + &mut diags, + ) + }); + let builtin_kind = match body.as_ref() { + baml_compiler2_hir::body::FunctionBody::Builtin(kind) => Some(*kind), + _ => None, + }; + drop(diags); + Some(ResolvedFunction { + name: func_data.name.clone(), + params, + return_type, + outward_throws: explicit_throws, + generic_params: func_data.generic_params.clone(), + builtin_kind, + }) + } + + fn lookup_dep_function( + &self, + db: &'db dyn crate::Db, + pkg_name: &Name, + namespace: &[Name], + item: &Name, + ) -> Option { + let dep_iface = self + .dep_interfaces + .iter() + .find_map(|(dep_name, dep_iface)| (dep_name == pkg_name).then_some(*dep_iface))?; + let exported = dep_iface.lookup_function(namespace, item)?; + let dep_pkg_id = self.dep_package_id(pkg_name)?; + Some(ResolvedFunction { + name: exported.name.clone(), + params: exported.params.clone(), + return_type: exported.return_type.clone(), + outward_throws: exported.declared_throws.clone().or_else(|| { + outward_throws_from_key(function_throw_sets(db, dep_pkg_id), &exported.callable_key) + }), + generic_params: exported.generic_params.clone(), + builtin_kind: exported.builtin_kind, + }) + } + /// Resolve a type by path. Own-package via `PackageItems`, then deps. pub fn resolve_type( &self, @@ -616,9 +815,12 @@ impl<'db> PackageResolutionContext<'db> { // dep-prefixed search (parity with resolve_type) for (dep_name, _dep_iface) in &self.dep_interfaces { if &path[0] == dep_name { - let dep_pkg_id = PackageId::new(db, dep_name.clone()); - let dep_items = package_items(db, dep_pkg_id); - if let Some(def) = dep_items.lookup_value(&path[1..path.len() - 1], item) { + if let Some(def) = self.lookup_value_definition_in_package( + db, + dep_name, + &path[1..path.len() - 1], + item, + ) { return Some((ResolvedSource::Builtin, def)); } } @@ -638,6 +840,93 @@ impl<'db> PackageResolutionContext<'db> { None } + pub fn lookup_value_definition_in_package( + &self, + db: &'db dyn crate::Db, + pkg_name: &Name, + namespace: &[Name], + item: &Name, + ) -> Option> { + let pkg_items = self.items_for_package(db, pkg_name)?; + pkg_items.lookup_value(namespace, item) + } + + pub fn lookup_type_definition_in_package( + &self, + db: &'db dyn crate::Db, + pkg_name: &Name, + namespace: &[Name], + item: &Name, + ) -> Option> { + let pkg_items = self.items_for_package(db, pkg_name)?; + pkg_items.lookup_type(namespace, item) + } + + pub fn lookup_class_loc( + &self, + db: &'db dyn crate::Db, + qtn: &QualifiedTypeName, + ) -> Option> { + match self.lookup_type_definition_in_package( + db, + qtn.package(), + qtn.namespace(), + qtn.name(), + )? { + Definition::Class(class_loc) => Some(class_loc), + _ => None, + } + } + + pub fn lookup_enum_loc( + &self, + db: &'db dyn crate::Db, + qtn: &QualifiedTypeName, + ) -> Option> { + match self.lookup_type_definition_in_package( + db, + qtn.package(), + qtn.namespace(), + qtn.name(), + )? { + Definition::Enum(enum_loc) => Some(enum_loc), + _ => None, + } + } + + pub fn lookup_enum_variants( + &self, + db: &'db dyn crate::Db, + enum_name: &QualifiedTypeName, + ) -> Vec { + if enum_name.package().as_str() == self.own_package_name.as_str() { + let Some(Definition::Enum(enum_loc)) = self.lookup_type_definition_in_package( + db, + enum_name.package(), + enum_name.namespace(), + enum_name.name(), + ) else { + return Vec::new(); + }; + let item_tree = file_item_tree(db, enum_loc.file(db)); + let enum_data = &item_tree[enum_loc.id(db)]; + enum_data.variants.iter().map(|v| v.name.clone()).collect() + } else { + self.dep_interfaces + .iter() + .find_map(|(dep_name, dep_iface)| { + if dep_name != enum_name.package() { + return None; + } + match dep_iface.lookup_type(enum_name.namespace(), enum_name.name())? { + ExportedType::Enum { variants, .. } => Some(variants.clone()), + _ => None, + } + }) + .unwrap_or_default() + } + } + /// Look up class fields. Dual dispatch: /// - Own-package: `ItemTree` -> lower fields + apply class-level generic substitution /// - Dependency: `ExportedType::Class` { fields } + apply class-level generic substitution @@ -726,6 +1015,9 @@ impl<'db> PackageResolutionContext<'db> { if dep_name != class_pkg { continue; } + let Some(dep_pkg_id) = self.dep_package_id(dep_name) else { + continue; + }; if let Some(ExportedType::Class { methods, generic_params, @@ -738,7 +1030,12 @@ impl<'db> PackageResolutionContext<'db> { name: method.name.clone(), params: method.params.clone(), return_type: method.return_type.clone(), - throws: method.throws.clone(), + outward_throws: method.declared_throws.clone().or_else(|| { + outward_throws_from_key( + function_throw_sets(db, dep_pkg_id), + &method.callable_key, + ) + }), generic_params: method.generic_params.clone(), builtin_kind: method.builtin_kind, }, @@ -761,9 +1058,11 @@ impl<'db> PackageResolutionContext<'db> { &resolved_method.function.return_type, &bindings, ); - if let Some(ref throws) = resolved_method.function.throws.clone() { - resolved_method.function.throws = - Some(crate::generics::substitute_ty(throws, &bindings)); + if let Some(ref outward_throws) = + resolved_method.function.outward_throws.clone() + { + resolved_method.function.outward_throws = + Some(crate::generics::substitute_ty(outward_throws, &bindings)); } } return Some(resolved_method); @@ -774,6 +1073,232 @@ impl<'db> PackageResolutionContext<'db> { } } + pub fn lookup_class_method_locs( + &self, + db: &'db dyn crate::Db, + class_name: &NominalTypeRef, + method_name: &Name, + ) -> Option<( + baml_compiler2_hir::loc::ClassLoc<'db>, + baml_compiler2_hir::loc::FunctionLoc<'db>, + )> { + let class_loc = self.lookup_class_loc(db, class_name.qtn())?; + let item_tree = file_item_tree(db, class_loc.file(db)); + let class_data = &item_tree[class_loc.id(db)]; + let func_loc = class_data.methods.iter().find_map(|&method_id| { + let method_data = &item_tree[method_id]; + if method_data.name == *method_name { + Some(baml_compiler2_hir::loc::FunctionLoc::new( + db, + class_loc.file(db), + method_id, + )) + } else { + None + } + })?; + Some((class_loc, func_loc)) + } + + /// Resolve a member on a builtin class declared in the `baml` package. + /// + /// This keeps the raw builtin stub inspection inside `PackageResolutionContext` + /// so callers do not need to reach into foreign `PackageItems` or `ItemTree` + /// directly. Method parameters still get specialized lowering so omitted + /// throws on direct callback params become synthetic effect variables. + pub fn resolve_builtin_member( + &self, + db: &'db dyn crate::Db, + class_path: &[Name], + type_args: &[Ty], + member_name: &Name, + ) -> Option> { + let baml_items = self.items_for_package(db, &Name::new("baml"))?; + let item = class_path.last()?; + let def = baml_items.lookup_type(&class_path[..class_path.len() - 1], item)?; + let Definition::Class(class_loc) = def else { + return None; + }; + + let file = class_loc.file(db); + let stub_pkg = file_package::file_package(db, file); + let stub_ns: &[Name] = &stub_pkg.namespace_path; + let item_tree = file_item_tree(db, file); + let class_data = &item_tree[class_loc.id(db)]; + + let mut bindings = crate::generics::bind_type_vars(&class_data.generic_params, type_args); + + for &method_id in &class_data.methods { + let method_data = &item_tree[method_id]; + if method_data.name != *member_name { + continue; + } + + for gp in &method_data.generic_params { + bindings + .entry(gp.clone()) + .or_insert_with(|| Ty::TypeVar(gp.clone(), TyAttr::default())); + } + + let func_loc = baml_compiler2_hir::loc::FunctionLoc::new(db, file, method_id); + let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); + let mut diags = Vec::new(); + + let builtin_class_ty = if type_args.is_empty() { + Ty::Class( + QualifiedTypeName::new_with_generic_params( + stub_pkg.package.clone(), + stub_pkg.namespace_path.clone(), + class_data.name.clone(), + class_data.generic_params.clone(), + ) + .into(), + TyAttr::default(), + ) + } else if type_args.len() == 1 { + match class_data.name.as_str() { + "Array" => Ty::List(Box::new(type_args[0].clone()), TyAttr::default()), + _ => Ty::Class( + QualifiedTypeName::new( + stub_pkg.package.clone(), + stub_pkg.namespace_path.clone(), + class_data.name.clone(), + ) + .into(), + TyAttr::default(), + ), + } + } else if type_args.len() == 2 { + match class_data.name.as_str() { + "Map" => Ty::Map( + Box::new(type_args[0].clone()), + Box::new(type_args[1].clone()), + TyAttr::default(), + ), + _ => Ty::Class( + QualifiedTypeName::new( + stub_pkg.package.clone(), + stub_pkg.namespace_path.clone(), + class_data.name.clone(), + ) + .into(), + TyAttr::default(), + ), + } + } else { + Ty::Class( + QualifiedTypeName::new( + stub_pkg.package.clone(), + stub_pkg.namespace_path.clone(), + class_data.name.clone(), + ) + .into(), + TyAttr::default(), + ) + }; + + let method_generic_params: Vec = class_data + .generic_params + .iter() + .chain(method_data.generic_params.iter()) + .cloned() + .collect(); + + let mut synthetic_effect_vars: Vec = Vec::new(); + let params: Vec<(Option, Ty)> = sig + .params + .iter() + .map(|(n, te)| { + let ty = if n.as_str() == "self" + && matches!(te, baml_compiler2_ast::TypeExpr::Unknown { .. }) + { + builtin_class_ty.clone() + } else if matches!(te, baml_compiler2_ast::TypeExpr::Function { .. }) { + let lowered = crate::lower_type_expr::lower_type_expr_with_fn_context( + db, + te, + baml_items, + stub_ns, + &method_generic_params, + &mut diags, + &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { + param_name: n.clone(), + }, + &mut synthetic_effect_vars, + ); + crate::generics::substitute_ty(&lowered, &bindings) + } else { + crate::generics::lower_type_expr_with_generics( + db, te, baml_items, stub_ns, &bindings, &mut diags, + ) + }; + (Some(n.clone()), ty) + }) + .collect(); + + let ret = sig + .return_type + .as_ref() + .map(|te| { + crate::generics::lower_type_expr_with_generics( + db, te, baml_items, stub_ns, &bindings, &mut diags, + ) + }) + .unwrap_or(Ty::Void { + attr: TyAttr::default(), + }); + + let throws_ty = match synthetic_effect_vars.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => Ty::TypeVar(synthetic_effect_vars[0].clone(), TyAttr::default()), + _ => Ty::Union( + synthetic_effect_vars + .iter() + .map(|v| Ty::TypeVar(v.clone(), TyAttr::default())) + .collect(), + TyAttr::default(), + ), + }; + + drop(diags); + return Some(ResolvedBuiltinMember::Method { + ty: Ty::Function { + params, + ret: Box::new(ret), + throws: Box::new(throws_ty), + attr: TyAttr::default(), + }, + class_loc, + func_loc, + }); + } + + for field in &class_data.fields { + if field.name != *member_name { + continue; + } + + let mut diags = Vec::new(); + let field_ty = field + .type_expr + .as_ref() + .map(|te| { + crate::generics::lower_type_expr_with_generics( + db, &te.expr, baml_items, stub_ns, &bindings, &mut diags, + ) + }) + .unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }); + drop(diags); + return Some(ResolvedBuiltinMember::Field(field_ty)); + } + + None + } + /// Apply class-level generic substitution to a `ResolvedMethod`'s function signature. /// Uses `generic_params` from the own-package item tree. fn apply_method_substitution( @@ -799,9 +1324,9 @@ impl<'db> PackageResolutionContext<'db> { .collect(); resolved_method.function.return_type = crate::generics::substitute_ty(&resolved_method.function.return_type, &bindings); - if let Some(ref throws) = resolved_method.function.throws.clone() { - resolved_method.function.throws = - Some(crate::generics::substitute_ty(throws, &bindings)); + if let Some(ref outward_throws) = resolved_method.function.outward_throws.clone() { + resolved_method.function.outward_throws = + Some(crate::generics::substitute_ty(outward_throws, &bindings)); } } @@ -904,7 +1429,7 @@ impl<'db> PackageResolutionContext<'db> { ) }, ); - let throws = sig.throws.as_ref().map(|te| { + let explicit_throws = sig.throws.as_ref().map(|te| { lower_type_expr_in_ns(db, te, self.own_items, &ns, &all_generic_params, &mut diags) }); @@ -918,7 +1443,7 @@ impl<'db> PackageResolutionContext<'db> { name: method_data.name.clone(), params, return_type, - throws, + outward_throws: explicit_throws, generic_params: method_data.generic_params.clone(), builtin_kind, }, diff --git a/baml_language/crates/baml_compiler2_tir/src/resolve.rs b/baml_language/crates/baml_compiler2_tir/src/resolve.rs index 8b64afbbf0..82818613ab 100644 --- a/baml_language/crates/baml_compiler2_tir/src/resolve.rs +++ b/baml_language/crates/baml_compiler2_tir/src/resolve.rs @@ -118,20 +118,27 @@ pub fn resolve_name_at_in_scope<'db>( if let Some((_source, _ty)) = res_ctx.resolve_type(db, name_path, &pkg_info.namespace_path) { - let pkg_items = res_ctx.own_items; - if let Some(def) = pkg_items.lookup_type(&pkg_info.namespace_path, name) { + if let Some(def) = res_ctx.lookup_type_definition_in_package( + db, + &pkg_info.package, + &pkg_info.namespace_path, + name, + ) { return ResolvedName::Item(def); } - // The type was found in deps — search deps. - // Dep builtins are in the root namespace (&[]). for (dep_name, _) in &res_ctx.dep_interfaces { - if let Some(dep_items) = res_ctx.items_for_package(db, dep_name) { - if let Some(def) = dep_items - .lookup_type(&pkg_info.namespace_path, name) - .or_else(|| dep_items.lookup_type(&[], name)) - { - return ResolvedName::Builtin(def); - } + if let Some(def) = res_ctx + .lookup_type_definition_in_package( + db, + dep_name, + &pkg_info.namespace_path, + name, + ) + .or_else(|| { + res_ctx.lookup_type_definition_in_package(db, dep_name, &[], name) + }) + { + return ResolvedName::Builtin(def); } } } @@ -176,20 +183,16 @@ pub fn resolve_path_at<'db>( let own_pkg_id = PackageId::new(db, pkg_info.package); let res_ctx = crate::package_interface::package_resolution_context(db, own_pkg_id); - let Some(pkg_items) = res_ctx.items_for_package(db, &pkg_name) else { - return ResolvedName::Unknown; - }; - let after_pkg = &segments[1..]; // The path already includes namespace segments, so look up directly. let item = after_pkg .last() .expect("multi-segment path has elements after pkg prefix"); let ns = &after_pkg[..after_pkg.len() - 1]; - if let Some(def) = pkg_items.lookup_value(ns, item) { + if let Some(def) = res_ctx.lookup_value_definition_in_package(db, &pkg_name, ns, item) { return ResolvedName::Builtin(def); } - if let Some(def) = pkg_items.lookup_type(ns, item) { + if let Some(def) = res_ctx.lookup_type_definition_in_package(db, &pkg_name, ns, item) { return ResolvedName::Builtin(def); } diff --git a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs index 8b7802ab26..1a578c1cc9 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs @@ -11,7 +11,7 @@ use baml_base::Name; use baml_compiler2_ast::{Expr, ExprBody, Literal, Pattern, TypeExpr}; use baml_compiler2_hir::{ contributions::Definition, - package::{PackageId, PackageItems, package_dependencies, package_items}, + package::{PackageId, PackageItems, package_items}, }; use crate::{ @@ -61,7 +61,7 @@ impl FunctionThrowSets { } } -#[salsa::tracked(returns(ref))] +#[salsa::tracked(returns(ref), cycle_initial=function_throw_sets_initial)] pub fn function_throw_sets<'db>( db: &'db dyn crate::Db, package_id: PackageId<'db>, @@ -81,16 +81,8 @@ pub fn function_throw_sets<'db>( } } } - // Load dependency interfaces for cross-package throw lookup - let dep_interfaces: Vec<(Name, &crate::package_interface::PackageInterface)> = - package_dependencies(db, package_id) - .iter() - .map(|dep_id| { - let name = dep_id.name(db); - let iface = crate::package_interface::package_interface(db, *dep_id); - (name, iface) - }) - .collect(); + let dep_packages = &res_ctx.dep_packages; + let dep_interfaces = &res_ctx.dep_interfaces; let mut graph: crate::analysis::AnalysisGraph = crate::analysis::AnalysisGraph::new(); @@ -101,12 +93,12 @@ pub fn function_throw_sets<'db>( let mut direct_facts: BTreeMap> = BTreeMap::new(); for ns in pkg_items.namespaces.values() { - for (short_name, def) in &ns.values { + for def in ns.values.values() { let Definition::Function(func_loc) = def else { continue; }; - let key = function_key(db, *func_loc, short_name); + let key = callable_throw_key(db, *func_loc); let sig = baml_compiler2_hir::signature::function_signature(db, *func_loc); let body = baml_compiler2_hir::body::function_body(db, *func_loc); let item_tree = baml_compiler2_hir::file_item_tree(db, func_loc.file(db)); @@ -189,11 +181,8 @@ pub fn function_throw_sets<'db>( for &method_id in &class_data.methods { let method_data = &item_tree[method_id]; - let method_name = &method_data.name; let func_loc = baml_compiler2_hir::loc::FunctionLoc::new(db, file, method_id); - // Key as "ClassName.method_name" (with namespace prefix if any). - let method_short = Name::new(format!("{class_name}.{method_name}")); - let key = function_key(db, func_loc, &method_short); + let key = callable_throw_key(db, func_loc); let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); let body = baml_compiler2_hir::body::function_body(db, func_loc); @@ -282,12 +271,12 @@ pub fn function_throw_sets<'db>( continue; } for to in targets { - if let Some(dep_throws) = lookup_dep_throw_set(&dep_interfaces, to) { + if let Some(dep_throws) = lookup_dep_throw_set(db, dep_packages, dep_interfaces, to) { // Cross-package: merge dependency's transitive throw facts into caller's direct facts direct_facts .entry(from.clone()) .or_default() - .extend(dep_throws.iter().cloned()); + .extend(dep_throws); } else { // Same-package: will add edge after nodes are added // (edges added below) @@ -306,7 +295,7 @@ pub fn function_throw_sets<'db>( continue; } for to in targets { - if lookup_dep_throw_set(&dep_interfaces, to).is_none() { + if lookup_dep_throw_set(db, dep_packages, dep_interfaces, to).is_none() { graph.add_edge(from.clone(), to.clone()); } } @@ -326,6 +315,17 @@ pub fn function_throw_sets<'db>( FunctionThrowSets { direct, transitive } } +fn function_throw_sets_initial<'db>( + _db: &'db dyn crate::Db, + _id: salsa::Id, + _package_id: PackageId<'db>, +) -> FunctionThrowSets { + FunctionThrowSets { + direct: BTreeMap::new(), + transitive: BTreeMap::new(), + } +} + /// Build the throw-set lookup key for a function given its namespace path and short name. /// /// For top-level functions the key is just the short name; for namespaced @@ -343,14 +343,25 @@ pub fn throw_set_key(namespace_path: &[Name], short_name: &Name) -> Name { } } -fn function_key<'db>( +pub fn callable_throw_key<'db>( db: &'db dyn crate::Db, func: baml_compiler2_hir::loc::FunctionLoc<'db>, - short_name: &Name, ) -> Name { let file = func.file(db); + let item_tree = baml_compiler2_hir::file_item_tree(db, file); + let func_data = &item_tree[func.id(db)]; let pkg = baml_compiler2_hir::file_package::file_package(db, file); - throw_set_key(&pkg.namespace_path, short_name) + let short_name = item_tree + .classes + .values() + .find_map(|class_data| { + class_data + .methods + .contains(&func.id(db)) + .then(|| Name::new(format!("{}.{}", class_data.name, func_data.name))) + }) + .unwrap_or_else(|| func_data.name.clone()); + throw_set_key(&pkg.namespace_path, &short_name) } pub fn collect_direct_throws<'db>( @@ -628,151 +639,531 @@ fn collect_member_field_call_throws<'db>( // with TypeVar type_args, e.g. Handler not bare Handler class_context: Option<(&Name, &[Name])>, ) -> (BTreeSet, BTreeSet) { - let mut direct_facts = BTreeSet::new(); - let mut extra_edges = BTreeSet::new(); - - let pkg_items = res_ctx.own_items; - // Build param name → Ty map from the function signature. - let param_types: HashMap = sig - .params - .iter() - .map(|(name, te)| { - let mut diags = Vec::new(); - let ty = - lower_type_expr_in_ns(db, te, pkg_items, ns_context, generic_params, &mut diags); - (name.clone(), ty) - }) - .collect(); + MemberFieldCallCollector::new( + db, + res_ctx, + ns_context, + generic_params, + sig, + body, + aliases, + class_context, + ) + .collect() +} - // Extend with for-loop variable types. - // For `for (let elem in collection)` where `collection` is a known param of type `T[]`, - // map `elem` → `T`. This handles patterns like `for task in tasks` where - // `tasks: Task[]`. - let mut local_types: HashMap = param_types.clone(); - for (_, stmt) in body.stmts.iter() { - if let baml_compiler2_ast::Stmt::For { - binding, - collection, - .. - } = stmt - { - // Get the binding name - let binding_name = match &body.patterns[*binding] { - baml_compiler2_ast::Pattern::Binding(name) - | baml_compiler2_ast::Pattern::TypedBinding { name, .. } => name.clone(), - _ => continue, - }; - // Get the collection's type by resolving its path to a param - let collection_ty = match &body.exprs[*collection] { - Expr::Path(segments) if segments.len() == 1 => { - param_types.get(&segments[0]).cloned() +struct MemberFieldCallCollector<'a, 'db> { + db: &'db dyn crate::Db, + res_ctx: &'a crate::package_interface::PackageResolutionContext<'db>, + ns_context: &'a [Name], + generic_params: &'a [Name], + body: &'a ExprBody, + aliases: &'a HashMap, + class_context: Option<(&'a Name, &'a [Name])>, + direct_facts: BTreeSet, + extra_edges: BTreeSet, + initial_locals: HashMap, +} + +impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { + #[allow(clippy::too_many_arguments)] + fn new( + db: &'db dyn crate::Db, + res_ctx: &'a crate::package_interface::PackageResolutionContext<'db>, + ns_context: &'a [Name], + generic_params: &'a [Name], + sig: &'a baml_compiler2_hir::signature::FunctionSignature, + body: &'a ExprBody, + aliases: &'a HashMap, + class_context: Option<(&'a Name, &'a [Name])>, + ) -> Self { + let pkg_items = res_ctx.own_items; + let initial_locals = sig + .params + .iter() + .map(|(name, te)| { + let mut diags = Vec::new(); + let ty = lower_type_expr_in_ns( + db, + te, + pkg_items, + ns_context, + generic_params, + &mut diags, + ); + (name.clone(), ty) + }) + .collect(); + Self { + db, + res_ctx, + ns_context, + generic_params, + body, + aliases, + class_context, + direct_facts: BTreeSet::new(), + extra_edges: BTreeSet::new(), + initial_locals, + } + } + + fn collect(mut self) -> (BTreeSet, BTreeSet) { + if let Some(root_expr) = self.body.root_expr { + let mut env = self.initial_locals.clone(); + self.visit_expr(root_expr, &mut env); + } + (self.direct_facts, self.extra_edges) + } + + fn visit_expr(&mut self, expr_id: baml_compiler2_ast::ExprId, env: &mut HashMap) { + match &self.body.exprs[expr_id] { + Expr::Literal(_) + | Expr::ByteStringLiteral(_) + | Expr::Null + | Expr::Path(_) + | Expr::Lambda(_) + | Expr::Missing => {} + Expr::If { + condition, + then_branch, + else_branch, + } => { + self.visit_expr(*condition, env); + let mut then_env = env.clone(); + self.visit_expr(*then_branch, &mut then_env); + if let Some(else_expr) = else_branch { + let mut else_env = env.clone(); + self.visit_expr(*else_expr, &mut else_env); } - _ => None, - }; - if let Some(coll_ty) = collection_ty { - // Peel List type to get element type - let elem_ty = match coll_ty { - Ty::List(inner, _) => Some(*inner), - _ => None, - }; - if let Some(elem) = elem_ty { - local_types.insert(binding_name, elem); + } + Expr::Match { + scrutinee, arms, .. + } => { + self.visit_expr(*scrutinee, env); + for arm_id in arms { + let arm = &self.body.match_arms[*arm_id]; + let mut arm_env = env.clone(); + if let Some(guard) = arm.guard { + self.visit_expr(guard, &mut arm_env); + } + self.visit_expr(arm.body, &mut arm_env); } } + Expr::Catch { base, clauses } => { + self.visit_expr(*base, env); + for clause in clauses { + for arm_id in &clause.arms { + let arm = &self.body.catch_arms[*arm_id]; + let mut arm_env = env.clone(); + self.visit_expr(arm.body, &mut arm_env); + } + } + } + Expr::Throw { value } => self.visit_expr(*value, env), + Expr::Binary { lhs, rhs, .. } => { + self.visit_expr(*lhs, env); + self.visit_expr(*rhs, env); + } + Expr::Unary { expr, .. } | Expr::OptionalChain { expr } => { + self.visit_expr(*expr, env); + } + Expr::Call { callee, args } | Expr::OptionalCall { callee, args } => { + self.collect_member_call(*callee, env); + self.visit_expr(*callee, env); + for arg in args { + self.visit_expr(*arg, env); + } + } + Expr::Object { + fields, spreads, .. + } => { + for (_, value) in fields { + self.visit_expr(*value, env); + } + for spread in spreads { + self.visit_expr(spread.expr, env); + } + } + Expr::Array { elements } => { + for elem in elements { + self.visit_expr(*elem, env); + } + } + Expr::Map { entries } => { + for (key, value) in entries { + self.visit_expr(*key, env); + self.visit_expr(*value, env); + } + } + Expr::Block { stmts, tail_expr } => { + let mut block_env = env.clone(); + for stmt_id in stmts { + self.visit_stmt(*stmt_id, &mut block_env); + } + if let Some(tail_expr) = tail_expr { + self.visit_expr(*tail_expr, &mut block_env); + } + } + Expr::FieldAccess { base, .. } | Expr::OptionalFieldAccess { base, .. } => { + self.visit_expr(*base, env); + } + Expr::Index { base, index } | Expr::OptionalIndex { base, index } => { + self.visit_expr(*base, env); + self.visit_expr(*index, env); + } } } - for (_, expr) in body.exprs.iter() { - let callee_id = match expr { - Expr::Call { callee, .. } | Expr::OptionalCall { callee, .. } => *callee, - _ => continue, - }; - - // Match FieldAccess { base, field } pattern - let (base_id, field) = match &body.exprs[callee_id] { - Expr::FieldAccess { base, field } => (*base, field), - _ => continue, - }; - - // Resolve base to a type: direct parameter, for-loop variable, or `self` - let base_ty = match &body.exprs[base_id] { - Expr::Path(segments) if segments.len() == 1 => { - let base_name = &segments[0]; - if base_name.as_str() == "self" { - // self-based member calls — resolve self type from class_context. - // Build accurate self type with TypeVar type_args so that - // lookup_class_fields applies correct generic substitution. - class_context.and_then(|(class_name, class_generic_params)| { - let def = pkg_items.lookup_type(ns_context, class_name)?; - match def { - Definition::Class(_) => { - let qtn = qualify_def(db, def, class_name); - let type_args: Vec = class_generic_params - .iter() - .map(|name| Ty::TypeVar(name.clone(), TyAttr::default())) - .collect(); - let nominal = - crate::ty::NominalTypeRef::new_with_type_args(qtn, type_args); - Some(Ty::Class(nominal, TyAttr::default())) - } - _ => None, - } - }) - } else { - local_types.get(base_name).cloned() + fn visit_stmt(&mut self, stmt_id: baml_compiler2_ast::StmtId, env: &mut HashMap) { + match &self.body.stmts[stmt_id] { + baml_compiler2_ast::Stmt::Expr(expr_id) => self.visit_expr(*expr_id, env), + baml_compiler2_ast::Stmt::Let { + pattern, + type_annotation, + initializer, + .. + } => { + let inferred_ty = initializer.and_then(|expr_id| { + self.visit_expr(expr_id, env); + self.resolve_expr_ty(expr_id, env) + }); + if let Some((binding_name, binding_ty)) = + self.resolve_binding_ty(*pattern, *type_annotation, inferred_ty) + { + env.insert(binding_name, binding_ty); } } - _ => None, - }; + baml_compiler2_ast::Stmt::While { + condition, + body, + after, + .. + } => { + self.visit_expr(*condition, env); + let mut body_env = env.clone(); + self.visit_expr(*body, &mut body_env); + if let Some(after_stmt) = after { + self.visit_stmt(*after_stmt, env); + } + } + baml_compiler2_ast::Stmt::For { + binding, + collection, + body, + } => { + self.visit_expr(*collection, env); + let mut body_env = env.clone(); + if let Some(binding_name) = self.binding_name(*binding) + && let Some(elem_ty) = self.resolve_collection_element_ty(*collection, env) + { + body_env.insert(binding_name, elem_ty); + } + self.visit_expr(*body, &mut body_env); + } + baml_compiler2_ast::Stmt::Return(expr) => { + if let Some(expr_id) = expr { + self.visit_expr(*expr_id, env); + } + } + baml_compiler2_ast::Stmt::Assign { target, value } + | baml_compiler2_ast::Stmt::AssignOp { target, value, .. } => { + self.visit_expr(*target, env); + self.visit_expr(*value, env); + if let Some((binding_name, binding_ty)) = + self.resolve_assignment_ty(*target, *value, env) + { + env.insert(binding_name, binding_ty); + } + } + baml_compiler2_ast::Stmt::Throw { value } => self.visit_expr(*value, env), + baml_compiler2_ast::Stmt::Break + | baml_compiler2_ast::Stmt::Continue + | baml_compiler2_ast::Stmt::Missing + | baml_compiler2_ast::Stmt::HeaderComment { .. } => {} + } + } - let Some(base_ty) = base_ty else { - continue; + fn collect_member_call( + &mut self, + callee_id: baml_compiler2_ast::ExprId, + env: &HashMap, + ) { + let (base_id, field) = match &self.body.exprs[callee_id] { + Expr::FieldAccess { base, field } | Expr::OptionalFieldAccess { base, field } => { + (*base, field) + } + _ => return, }; - // Resolve the member on the base type - let resolved_base = crate::throws_semantics::resolve_alias_chain(&base_ty, aliases); + let Some(base_ty) = self.resolve_expr_ty(base_id, env) else { + return; + }; + let resolved_base = crate::throws_semantics::resolve_alias_chain(&base_ty, self.aliases); let Ty::Class(class_name, _) = &resolved_base else { - continue; + return; }; - // Look up the field via PackageResolutionContext (with generic substitution) - let fields = res_ctx.lookup_class_fields(db, class_name); - if let Some((_, field_ty)) = fields.iter().find(|(n, _)| n == field) { - // Function-typed field → extract throws as direct facts - if let Some(facts) = function_throws_facts(field_ty, aliases) { - direct_facts.extend(facts.into_iter().filter(|fact| { + let fields = self.res_ctx.lookup_class_fields(self.db, class_name); + if let Some((_, field_ty)) = fields.iter().find(|(name, _)| name == field) { + if let Some(facts) = function_throws_facts(field_ty, self.aliases) { + self.direct_facts.extend(facts.into_iter().filter(|fact| { !matches!(fact, Ty::TypeVar(_, _) | Ty::Never { .. } | Ty::Void { .. }) })); } - continue; + return; } - // Check if it's a named method → validate via res_ctx and add as call-graph edge. - // We MUST validate the method exists before adding an edge, because - // AnalysisGraph::add_edge (analysis.rs:54-58) auto-creates missing target - // nodes with empty facts. A bogus edge would inject a phantom node and - // alter graph topology, potentially masking real throw propagation. - // We also must use the class declaration namespace for the key, NOT the caller's - // ns_context, so the edge matches the declaration-side key built by function_key(). - if let Some(_resolved_method) = res_ctx.lookup_class_method(db, class_name, field) { + if let Some(_resolved_method) = self.res_ctx.lookup_class_method(self.db, class_name, field) + { let class_ns = class_name.namespace(); let method_short = Name::new(format!("{}.{}", class_name.name(), field)); let method_key = throw_set_key(class_ns, &method_short); - extra_edges.insert(method_key); + self.extra_edges.insert(method_key); + } + } + + fn resolve_expr_ty( + &self, + expr_id: baml_compiler2_ast::ExprId, + env: &HashMap, + ) -> Option { + match &self.body.exprs[expr_id] { + Expr::Path(segments) if segments.len() == 1 => { + let name = &segments[0]; + if name.as_str() == "self" { + self.self_ty() + } else { + env.get(name).cloned() + } + } + Expr::Path(segments) => self.named_callable_ty(segments), + Expr::FieldAccess { base, field } | Expr::OptionalFieldAccess { base, field } => { + let base_ty = self.resolve_expr_ty(*base, env)?; + self.resolve_member_ty(&base_ty, field) + } + Expr::Call { callee, .. } | Expr::OptionalCall { callee, .. } => { + let callee_ty = self.resolve_expr_ty(*callee, env)?; + match callee_ty { + Ty::Function { ret, .. } => Some(*ret), + Ty::Optional(inner, _) => match *inner { + Ty::Function { ret, .. } => Some(*ret), + _ => None, + }, + _ => None, + } + } + Expr::Object { + type_name: Some(name), + type_args, + .. + } => { + let def = self.res_ctx.own_items.lookup_type(self.ns_context, name)?; + match def { + Definition::Class(_) => { + let qtn = qualify_def(self.db, def, name); + let mut diags = Vec::new(); + let lowered_type_args = type_args + .iter() + .map(|te| { + lower_type_expr_in_ns( + self.db, + te, + self.res_ctx.own_items, + self.ns_context, + self.generic_params, + &mut diags, + ) + }) + .collect(); + Some(Ty::Class( + crate::ty::NominalTypeRef::new_with_type_args(qtn, lowered_type_args), + TyAttr::default(), + )) + } + Definition::Enum(_) => { + let qtn = qualify_def(self.db, def, name); + Some(Ty::Enum( + crate::ty::NominalTypeRef::new_with_type_args(qtn, Vec::new()), + TyAttr::default(), + )) + } + _ => None, + } + } + Expr::Array { elements } => { + let element_tys: Vec = elements + .iter() + .map(|expr_id| self.resolve_expr_ty(*expr_id, env)) + .collect::>>()?; + let element_ty = collapse_types(element_tys)?; + Some(Ty::List(Box::new(element_ty), TyAttr::default())) + } + Expr::OptionalChain { expr } => self.resolve_expr_ty(*expr, env), + _ => None, + } + } + + fn resolve_member_ty(&self, base_ty: &Ty, field: &Name) -> Option { + let resolved_base = crate::throws_semantics::resolve_alias_chain(base_ty, self.aliases); + let Ty::Class(class_name, _) = &resolved_base else { + return None; + }; + + let fields = self.res_ctx.lookup_class_fields(self.db, class_name); + if let Some((_, field_ty)) = fields.iter().find(|(name, _)| name == field) { + return Some(field_ty.clone()); } + + let method = self + .res_ctx + .lookup_class_method(self.db, class_name, field)?; + Some(Ty::Function { + params: method + .function + .params + .into_iter() + .map(|(name, ty)| (Some(name), ty)) + .collect(), + ret: Box::new(method.function.return_type), + throws: Box::new(method.function.outward_throws.unwrap_or(Ty::Never { + attr: TyAttr::default(), + })), + attr: TyAttr::default(), + }) } - (direct_facts, extra_edges) + fn named_callable_ty(&self, path: &[Name]) -> Option { + let (_source, function) = self + .res_ctx + .lookup_function(self.db, path, self.ns_context)?; + Some(Ty::Function { + params: function + .params + .into_iter() + .map(|(name, ty)| (Some(name), ty)) + .collect(), + ret: Box::new(function.return_type), + throws: Box::new(function.outward_throws.unwrap_or(Ty::Never { + attr: TyAttr::default(), + })), + attr: TyAttr::default(), + }) + } + + fn resolve_binding_ty( + &self, + pattern: baml_compiler2_ast::PatId, + type_annotation: Option, + inferred_ty: Option, + ) -> Option<(Name, Ty)> { + let binding_name = self.binding_name(pattern)?; + let explicit_ty = type_annotation + .map(|annot_id| self.lower_local_type_expr(&self.body.type_annotations[annot_id])) + .or_else(|| match &self.body.patterns[pattern] { + Pattern::TypedBinding { ty, .. } => Some(self.lower_local_type_expr(ty)), + _ => None, + }); + Some((binding_name, explicit_ty.or(inferred_ty)?)) + } + + fn resolve_assignment_ty( + &self, + target: baml_compiler2_ast::ExprId, + value: baml_compiler2_ast::ExprId, + env: &HashMap, + ) -> Option<(Name, Ty)> { + match &self.body.exprs[target] { + Expr::Path(segments) if segments.len() == 1 => { + Some((segments[0].clone(), self.resolve_expr_ty(value, env)?)) + } + _ => None, + } + } + + fn resolve_collection_element_ty( + &self, + collection: baml_compiler2_ast::ExprId, + env: &HashMap, + ) -> Option { + match self.resolve_expr_ty(collection, env)? { + Ty::List(inner, _) => Some(*inner), + _ => None, + } + } + + fn binding_name(&self, pattern: baml_compiler2_ast::PatId) -> Option { + match &self.body.patterns[pattern] { + Pattern::Binding(name) | Pattern::TypedBinding { name, .. } => Some(name.clone()), + _ => None, + } + } + + fn lower_local_type_expr(&self, te: &TypeExpr) -> Ty { + let mut diags = Vec::new(); + lower_type_expr_in_ns( + self.db, + te, + self.res_ctx.own_items, + self.ns_context, + self.generic_params, + &mut diags, + ) + } + + fn self_ty(&self) -> Option { + self.class_context + .and_then(|(class_name, class_generic_params)| { + let def = self + .res_ctx + .own_items + .lookup_type(self.ns_context, class_name)?; + match def { + Definition::Class(_) => { + let qtn = qualify_def(self.db, def, class_name); + let type_args: Vec = class_generic_params + .iter() + .map(|name| Ty::TypeVar(name.clone(), TyAttr::default())) + .collect(); + Some(Ty::Class( + crate::ty::NominalTypeRef::new_with_type_args(qtn, type_args), + TyAttr::default(), + )) + } + _ => None, + } + }) + } +} + +fn collapse_types(types: Vec) -> Option { + let mut unique = BTreeSet::new(); + for ty in types { + unique.insert(ty); + } + match unique.len() { + 0 => None, + 1 => unique.into_iter().next(), + _ => Some(Ty::Union(unique.into_iter().collect(), TyAttr::default())), + } } /// Look up a function's transitive throw set from dependency interfaces. -fn lookup_dep_throw_set<'a>( - dep_interfaces: &'a [(Name, &crate::package_interface::PackageInterface)], +fn lookup_dep_throw_set<'db>( + db: &'db dyn crate::Db, + dep_packages: &[(Name, baml_compiler2_hir::package::PackageId<'db>)], + dep_interfaces: &[(Name, &crate::package_interface::PackageInterface)], target_name: &Name, -) -> Option<&'a BTreeSet> { - for (_dep_name, dep_iface) in dep_interfaces { - if let Some(throws) = dep_iface.throw_sets.transitive_for(target_name) { - return Some(throws); +) -> Option> { + for (dep_name, dep_iface) in dep_interfaces { + if !dep_iface.callable_keys.contains(target_name) { + continue; + } + let dep_pkg_id = dep_packages + .iter() + .find_map(|(pkg_name, dep_pkg_id)| (pkg_name == dep_name).then_some(*dep_pkg_id))?; + if let Some(throws) = function_throw_sets(db, dep_pkg_id).transitive_for(target_name) { + return Some(throws.clone()); } } None diff --git a/baml_language/crates/baml_lsp2_actions/src/definition.rs b/baml_language/crates/baml_lsp2_actions/src/definition.rs index 7a531a57f3..6e4c56b5b8 100644 --- a/baml_language/crates/baml_lsp2_actions/src/definition.rs +++ b/baml_language/crates/baml_lsp2_actions/src/definition.rs @@ -211,6 +211,12 @@ fn resolve_member_at( }; let source_map = baml_compiler2_hir::body::function_body_source_map(db, func_loc)?; + if let Some(loc) = + resolve_path_variant_at(db, file, offset, token_text, expr_body, &source_map) + { + return Some(loc); + } + // Try FieldAccess path first if let Some(loc) = resolve_field_access_at( db, @@ -290,10 +296,69 @@ fn resolve_variant_definition( None } +fn resolve_path_variant_at( + db: &dyn Db, + file: SourceFile, + offset: TextSize, + token_text: &str, + expr_body: &baml_compiler2_ast::ExprBody, + source_map: &baml_compiler2_ast::AstSourceMap, +) -> Option { + use baml_compiler2_ast::Expr; + use baml_compiler2_tir::ty::Ty; + + let pkg_info = baml_compiler2_hir::file_package::file_package(db, file); + let pkg_id = baml_compiler2_hir::package::PackageId::new(db, pkg_info.package.clone()); + let res_ctx = baml_compiler2_tir::package_interface::package_resolution_context(db, pkg_id); + + let mut best: Option<(baml_compiler2_ast::ExprId, text_size::TextRange)> = None; + for (expr_id, expr) in expr_body.exprs.iter() { + let Expr::Path(segments) = expr else { + continue; + }; + if segments.len() < 2 || segments.last()?.as_str() != token_text { + continue; + } + let span = source_map.expr_span(expr_id); + if !span.contains(offset) && span.end() != offset { + continue; + } + match best { + Some((_, prev_span)) if span.len() >= prev_span.len() => {} + _ => best = Some((expr_id, span)), + } + } + + let (expr_id, _) = best?; + let baml_compiler2_ast::Expr::Path(segments) = &expr_body.exprs[expr_id] else { + return None; + }; + let variant_name = segments.last()?; + let prefix = &segments[..segments.len() - 1]; + let (_source, ty) = res_ctx.resolve_type(db, prefix, &pkg_info.namespace_path)?; + let Ty::Enum(enum_name, _) = ty else { + return None; + }; + let enum_loc = res_ctx.lookup_enum_loc(db, enum_name.qtn())?; + let target_file = enum_loc.file(db); + let target_item_tree = baml_compiler2_hir::file_item_tree(db, target_file); + let target_source_map = baml_compiler2_hir::file_item_tree_source_map(db, target_file); + let enum_def = &target_item_tree[enum_loc.id(db)]; + let variant_idx = enum_def + .variants + .iter() + .position(|variant| variant.name == *variant_name)?; + let variant_spans = target_source_map.enum_variant_spans.get(&enum_loc.id(db))?; + Some(Location { + file: target_file, + range: variant_spans[variant_idx], + }) +} + /// Cursor is on a `FieldAccess` expression (e.g. `p.name`, `Status.Active`, `s.Celebrate()`). fn resolve_field_access_at( db: &dyn Db, - _file: SourceFile, + file: SourceFile, offset: TextSize, token_text: &str, expr_body: &baml_compiler2_ast::ExprBody, @@ -301,7 +366,7 @@ fn resolve_field_access_at( inference: &baml_compiler2_tir::inference::ScopeInference<'_>, ) -> Option { use baml_compiler2_ast::Expr; - use baml_compiler2_tir::inference::MemberResolution; + use baml_compiler2_tir::{inference::MemberResolution, ty::Ty}; // Find the FieldAccess expr whose span contains the cursor and field name matches. // Pick smallest (innermost) span for nested chains like a.b.c. @@ -322,11 +387,11 @@ fn resolve_field_access_at( } let (expr_id, _) = best?; - match inference.resolution(expr_id)? { - MemberResolution::Field { + match inference.resolution(expr_id) { + Some(MemberResolution::Field { class_loc, field_name, - } => { + }) => { let target_file = class_loc.file(db); let target_item_tree = baml_compiler2_hir::file_item_tree(db, target_file); let target_source_map = baml_compiler2_hir::file_item_tree_source_map(db, target_file); @@ -338,10 +403,10 @@ fn resolve_field_access_at( range: field_spans[field_idx], }) } - MemberResolution::Variant { + Some(MemberResolution::Variant { enum_loc, variant_name, - } => { + }) => { let target_file = enum_loc.file(db); let target_item_tree = baml_compiler2_hir::file_item_tree(db, target_file); let target_source_map = baml_compiler2_hir::file_item_tree_source_map(db, target_file); @@ -356,7 +421,7 @@ fn resolve_field_access_at( range: variant_spans[variant_idx], }) } - MemberResolution::Free { func_loc } => { + Some(MemberResolution::Free { func_loc }) => { let def = baml_compiler2_hir::contributions::Definition::Function(*func_loc); let (def_file, range) = utils::definition_span(db, def)?; Some(Location { @@ -364,7 +429,7 @@ fn resolve_field_access_at( range, }) } - MemberResolution::Method { func_loc, .. } => { + Some(MemberResolution::Method { func_loc, .. }) => { // Methods are not in FileSymbolContributions — use ItemTreeSourceMap. let target_file = func_loc.file(db); let target_source_map = baml_compiler2_hir::file_item_tree_source_map(db, target_file); @@ -376,6 +441,41 @@ fn resolve_field_access_at( range: *name_range, }) } + None => { + let Expr::FieldAccess { base, field } = &expr_body.exprs[expr_id] else { + return None; + }; + if field.as_str() != token_text { + return None; + } + + let Expr::Path(segments) = &expr_body.exprs[*base] else { + return None; + }; + + let pkg_info = baml_compiler2_hir::file_package::file_package(db, file); + let pkg_id = baml_compiler2_hir::package::PackageId::new(db, pkg_info.package.clone()); + let res_ctx = + baml_compiler2_tir::package_interface::package_resolution_context(db, pkg_id); + let (_source, ty) = res_ctx.resolve_type(db, segments, &pkg_info.namespace_path)?; + let Ty::Enum(enum_name, _) = ty else { + return None; + }; + let enum_loc = res_ctx.lookup_enum_loc(db, enum_name.qtn())?; + let target_file = enum_loc.file(db); + let target_item_tree = baml_compiler2_hir::file_item_tree(db, target_file); + let target_source_map = baml_compiler2_hir::file_item_tree_source_map(db, target_file); + let enum_def = &target_item_tree[enum_loc.id(db)]; + let variant_idx = enum_def + .variants + .iter() + .position(|variant| variant.name == *field)?; + let variant_spans = target_source_map.enum_variant_spans.get(&enum_loc.id(db))?; + Some(Location { + file: target_file, + range: variant_spans[variant_idx], + }) + } } } diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info.rs b/baml_language/crates/baml_lsp2_actions/src/type_info.rs index bbdebee5f1..a5f38e2164 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info.rs @@ -191,7 +191,9 @@ pub fn type_at(db: &dyn Db, file: SourceFile, offset: TextSize) -> Option None, + | baml_compiler2_tir::resolve::ResolvedName::Unknown => { + class_method_definition_type_info(db, file, offset, name_text) + } } } @@ -226,10 +228,11 @@ fn type_info_for_definition(db: &dyn Db, def: Definition<'_>) -> TypeInfo { .collect(); let return_type = sig.return_type.as_ref().map(utils::display_type_expr); - // Compute throws: prefer inferred/transitive throws from TIR, fall back - // to declared throws from the signature. - let throws = inferred_throws_for_function(db, func_loc, func_data) - .or_else(|| sig.throws.as_ref().map(utils::display_type_expr)); + // Explicit throws clauses should display exactly as written in source. + // Only fall back to inferred/transitive throws when the signature + // omits a throws clause entirely. + let throws = declared_throws_for_function(db, func_loc) + .or_else(|| inferred_throws_for_function(db, func_loc)); TypeInfo::Function { name: func_name, @@ -366,16 +369,11 @@ fn type_info_for_definition(db: &dyn Db, def: Definition<'_>) -> TypeInfo { fn inferred_throws_for_function( db: &dyn Db, func_loc: baml_compiler2_hir::loc::FunctionLoc<'_>, - func_data: &baml_compiler2_hir::item_tree::Function, ) -> Option { let pkg_info = baml_compiler2_hir::file_package::file_package(db, func_loc.file(db)); let pkg_id = baml_compiler2_hir::package::PackageId::new(db, pkg_info.package); let throw_sets = baml_compiler2_tir::throw_inference::function_throw_sets(db, pkg_id); - - let key = baml_compiler2_tir::throw_inference::throw_set_key( - &pkg_info.namespace_path, - &func_data.name, - ); + let key = baml_compiler2_tir::throw_inference::callable_throw_key(db, func_loc); let facts = throw_sets.transitive_for(&key)?; if facts.is_empty() { @@ -386,6 +384,47 @@ fn inferred_throws_for_function( Some(parts.join(" | ")) } +fn declared_throws_for_function( + db: &dyn Db, + func_loc: baml_compiler2_hir::loc::FunctionLoc<'_>, +) -> Option { + let sig_sm = baml_compiler2_hir::signature::function_signature_source_map(db, func_loc); + let span = sig_sm.throws_type_span?; + let text = func_loc.file(db).text(db); + let start: usize = span.start().into(); + let end: usize = span.end().into(); + Some(text[start..end].trim().to_string()) +} + +fn class_method_definition_type_info( + db: &dyn Db, + file: SourceFile, + offset: TextSize, + token_text: &str, +) -> Option { + let item_tree = baml_compiler2_hir::file_item_tree(db, file); + let source_map = baml_compiler2_hir::file_item_tree_source_map(db, file); + + for class_data in item_tree.classes.values() { + for &method_id in &class_data.methods { + let method_data = &item_tree[method_id]; + if method_data.name.as_str() != token_text { + continue; + } + let Some(name_span) = source_map.function_name_spans.get(&method_id).copied() else { + continue; + }; + if !name_span.contains(offset) && name_span.end() != offset { + continue; + } + let func_loc = baml_compiler2_hir::loc::FunctionLoc::new(db, file, method_id); + return Some(type_info_for_definition(db, Definition::Function(func_loc))); + } + } + + None +} + // ── local_type_info ─────────────────────────────────────────────────────────── /// Build `TypeInfo::LocalVar` for a local variable (let binding or parameter). diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs index 568c0eae9d..6988d40943 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs @@ -352,10 +352,70 @@ function <[CURSOR]multi_throw(x: int) -> string throws string | int { "#, ); let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws.as_deref(), Some("string | int")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws string | int"), + "Hover should preserve the declared throws clause, got: {md}" + ); + } + + #[test] + fn hover_method_with_inferred_throws_from_member_call() { + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +class Runner { + function <[CURSOR]use_handler(self, h: Handler) -> null { + h.run() + } +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws.as_deref(), Some("string")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws string"), + "Method hover should show inferred throws, got: {md}" + ); + } + + #[test] + fn hover_method_with_declared_union_throws_preserves_source_clause() { + let test = CursorTest::new( + r#" +class Runner { + function <[CURSOR]declared(self) -> null throws string | int { + null + } +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws.as_deref(), Some("string | int")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } let md = info.to_hover_markdown(); assert!( - md.contains("throws"), - "Hover should contain throws clause, got: {md}" + md.contains("throws string | int"), + "Method hover should preserve the declared throws clause, got: {md}" ); } } diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/nested_function_throws_validation.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/nested_function_throws_validation.baml index 36c40c59ba..6ae6ae4dd5 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/nested_function_throws_validation.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/throw/nested_function_throws_validation.baml @@ -12,17 +12,17 @@ function BadNestedUnknownAttr(cb: () -> int throws string @unknown) -> int { // ╭─[ throw_nested_function_throws_validation.baml:1:30 ] // │ // 1 │ function BadNestedBuiltin(cb: () -> int throws $rust_type) -> int { -// │ ──────────────┬───────────── +// │ ──────────────┬───────────── // │ ╰─────────────── builtin-only syntax -// │ +// │ // │ Note: Error code: E0016 // ───╯ // Error: Unknown attribute `@unknown` // ╭─[ throw_nested_function_throws_validation.baml:5:59 ] // │ // 5 │ function BadNestedUnknownAttr(cb: () -> int throws string @unknown) -> int { -// │ ────┬─── +// │ ────┬─── // │ ╰───── unknown attribute -// │ +// │ // │ Note: Error code: E0015 // ───╯ diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/generic_stored_callback.baml b/baml_language/crates/baml_tests/projects/function_type_throws/generic_stored_callback.baml index cfd35b4bf5..7766016c01 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/generic_stored_callback.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/generic_stored_callback.baml @@ -13,6 +13,21 @@ function use_int_handler(h: Handler) -> null { h.run() } +function use_string_handler_alias(h: Handler) -> null { + let alias = h + alias.run() +} + +function use_local_string_handler() -> null { + let h = make_string_handler() + h.run() +} + +function maybe_use_string_handler(h: Handler?) -> null { + h?.run() + null +} + function make_string_handler() -> Handler { Handler { run: () -> null { throw "error" } } } @@ -30,3 +45,15 @@ function test_string_handler() -> null { function test_int_handler() -> null { use_int_handler(make_int_handler()) } + +function test_string_handler_alias() -> null { + use_string_handler_alias(make_string_handler()) +} + +function test_local_string_handler() -> null { + use_local_string_handler() +} + +function test_optional_string_handler() -> null { + maybe_use_string_handler(make_string_handler()) +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/heterogeneous_collection.baml b/baml_language/crates/baml_tests/projects/function_type_throws/heterogeneous_collection.baml index b4abea0bae..e8a17f04e8 100644 --- a/baml_language/crates/baml_tests/projects/function_type_throws/heterogeneous_collection.baml +++ b/baml_language/crates/baml_tests/projects/function_type_throws/heterogeneous_collection.baml @@ -13,6 +13,14 @@ function run_tasks(tasks: Task[]) -> null { null } +function run_tasks_alias(tasks: Task[]) -> null { + let alias = tasks + for (let task in alias) { + task.run() + } + null +} + function make_mixed_tasks() -> Task[] { [ Task { name: "a", run: () -> null { throw "error" } }, @@ -20,6 +28,37 @@ function make_mixed_tasks() -> Task[] { ] } +function run_local_tasks() -> null { + let tasks = make_mixed_tasks() + for (let task in tasks) { + task.run() + } + null +} + +function run_inline_tasks() -> null { + let tasks = [ + Task { name: "a", run: () -> null { throw "error" } }, + Task { name: "b", run: () -> null { throw 42 } }, + ] + for (let task in tasks) { + task.run() + } + null +} + function test_run_mixed_tasks() -> null { run_tasks(make_mixed_tasks()) } + +function test_run_tasks_alias() -> null { + run_tasks_alias(make_mixed_tasks()) +} + +function test_run_local_tasks() -> null { + run_local_tasks() +} + +function test_run_inline_tasks() -> null { + run_inline_tasks() +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_stored_callback.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_stored_callback.snap index 0632d20855..7897990f36 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_stored_callback.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__generic_stored_callback.snap @@ -82,6 +82,69 @@ LParen "(" RParen ")" RBrace "}" Function "function" +Word "use_string_handler_alias" +LParen "(" +Word "h" +Colon ":" +Word "Handler" +Less "<" +Word "string" +Greater ">" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "alias" +Equals "=" +Word "h" +Word "alias" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "use_local_string_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "h" +Equals "=" +Word "make_string_handler" +LParen "(" +RParen ")" +Word "h" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "maybe_use_string_handler" +LParen "(" +Word "h" +Colon ":" +Word "Handler" +Less "<" +Word "string" +Greater ">" +Question "?" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "h" +QuestionDot "?." +Word "run" +LParen "(" +RParen ")" +Word "null" +RBrace "}" +Function "function" Word "make_string_handler" LParen "(" RParen ")" @@ -177,3 +240,42 @@ LParen "(" RParen ")" RParen ")" RBrace "}" +Function "function" +Word "test_string_handler_alias" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "use_string_handler_alias" +LParen "(" +Word "make_string_handler" +LParen "(" +RParen ")" +RParen ")" +RBrace "}" +Function "function" +Word "test_local_string_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "use_local_string_handler" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_optional_string_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "maybe_use_string_handler" +LParen "(" +Word "make_string_handler" +LParen "(" +RParen ")" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__heterogeneous_collection.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__heterogeneous_collection.snap index 0e978ee3ac..bfde1af736 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__heterogeneous_collection.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__heterogeneous_collection.snap @@ -74,6 +74,43 @@ RBrace "}" Word "null" RBrace "}" Function "function" +Word "run_tasks_alias" +LParen "(" +Word "tasks" +Colon ":" +Word "Task" +Less "<" +Word "string" +Pipe "|" +Word "int" +Greater ">" +LBracket "[" +RBracket "]" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "alias" +Equals "=" +Word "tasks" +For "for" +LParen "(" +Let "let" +Word "task" +In "in" +Word "alias" +RParen ")" +LBrace "{" +Word "task" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Word "null" +RBrace "}" +Function "function" Word "make_mixed_tasks" LParen "(" RParen ")" @@ -133,6 +170,105 @@ Comma "," RBracket "]" RBrace "}" Function "function" +Word "run_local_tasks" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "tasks" +Equals "=" +Word "make_mixed_tasks" +LParen "(" +RParen ")" +For "for" +LParen "(" +Let "let" +Word "task" +In "in" +Word "tasks" +RParen ")" +LBrace "{" +Word "task" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Word "null" +RBrace "}" +Function "function" +Word "run_inline_tasks" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "tasks" +Equals "=" +LBracket "[" +Word "Task" +LBrace "{" +Word "name" +Colon ":" +Quote "\"" +Word "a" +Quote "\"" +Comma "," +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "error" +Quote "\"" +RBrace "}" +RBrace "}" +Comma "," +Word "Task" +LBrace "{" +Word "name" +Colon ":" +Quote "\"" +Word "b" +Quote "\"" +Comma "," +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +IntegerLiteral "42" +RBrace "}" +RBrace "}" +Comma "," +RBracket "]" +For "for" +LParen "(" +Let "let" +Word "task" +In "in" +Word "tasks" +RParen ")" +LBrace "{" +Word "task" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Word "null" +RBrace "}" +Function "function" Word "test_run_mixed_tasks" LParen "(" RParen ")" @@ -146,3 +282,39 @@ LParen "(" RParen ")" RParen ")" RBrace "}" +Function "function" +Word "test_run_tasks_alias" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "run_tasks_alias" +LParen "(" +Word "make_mixed_tasks" +LParen "(" +RParen ")" +RParen ")" +RBrace "}" +Function "function" +Word "test_run_local_tasks" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "run_local_tasks" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "test_run_inline_tasks" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "run_inline_tasks" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_stored_callback.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_stored_callback.snap index 8ba7ec39ed..7b2585267f 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_stored_callback.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__generic_stored_callback.snap @@ -88,6 +88,105 @@ SOURCE_FILE L_PAREN "(" R_PAREN ")" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "use_string_handler_alias" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "h" + COLON ":" + TYPE_EXPR + WORD "Handler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let alias = h" + KW_LET "let" + WORD "alias" + EQUALS "=" + WORD "h" + CALL_EXPR + PATH_EXPR "alias.run" + WORD "alias" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "use_local_string_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "h" + EQUALS "=" + CALL_EXPR + WORD "make_string_handler" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + CALL_EXPR + PATH_EXPR "h.run" + WORD "h" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "maybe_use_string_handler" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "h" + COLON ":" + TYPE_EXPR + WORD "Handler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + QUESTION "?" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + OPTIONAL_FIELD_ACCESS_EXPR "h?.run" + WORD "h" + QUESTION_DOT "?." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + WORD "null" + R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" WORD "make_string_handler" @@ -215,6 +314,70 @@ SOURCE_FILE R_PAREN ")" R_PAREN ")" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_string_handler_alias" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "use_string_handler_alias" + CALL_ARGS + L_PAREN "(" + CALL_EXPR + WORD "make_string_handler" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_local_string_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "use_local_string_handler" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_optional_string_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "maybe_use_string_handler" + CALL_ARGS + L_PAREN "(" + CALL_EXPR + WORD "make_string_handler" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_PAREN ")" + R_BRACE "}" === ERRORS === None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__heterogeneous_collection.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__heterogeneous_collection.snap index de7c62bfdd..77ff498ede 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__heterogeneous_collection.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__heterogeneous_collection.snap @@ -79,6 +79,59 @@ SOURCE_FILE R_BRACE "}" WORD "null" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "run_tasks_alias" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "tasks" + COLON ":" + TYPE_EXPR + WORD "Task" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string | int" + WORD "string" + PIPE "|" + WORD "int" + GREATER ">" + L_BRACKET "[" + R_BRACKET "]" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let alias = tasks" + KW_LET "let" + WORD "alias" + EQUALS "=" + WORD "tasks" + FOR_EXPR + KW_FOR "for" + L_PAREN "(" + LET_STMT "let task" + KW_LET "let" + WORD "task" + KW_IN "in" + WORD "alias" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + PATH_EXPR "task.run" + WORD "task" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + WORD "null" + R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" WORD "make_mixed_tasks" @@ -167,6 +220,153 @@ SOURCE_FILE COMMA "," R_BRACKET "]" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "run_local_tasks" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "tasks" + EQUALS "=" + CALL_EXPR + WORD "make_mixed_tasks" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + FOR_EXPR + KW_FOR "for" + L_PAREN "(" + LET_STMT "let task" + KW_LET "let" + WORD "task" + KW_IN "in" + WORD "tasks" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + PATH_EXPR "task.run" + WORD "task" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + WORD "null" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "run_inline_tasks" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "tasks" + EQUALS "=" + ARRAY_LITERAL + L_BRACKET "[" + OBJECT_LITERAL + WORD "Task" + L_BRACE "{" + OBJECT_FIELD + WORD "name" + COLON ":" + STRING_LITERAL "a" + QUOTE """ + WORD "a" + QUOTE """ + COMMA "," + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "error" + QUOTE """ + WORD "error" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + COMMA "," + OBJECT_LITERAL + WORD "Task" + L_BRACE "{" + OBJECT_FIELD + WORD "name" + COLON ":" + STRING_LITERAL "b" + QUOTE """ + WORD "b" + QUOTE """ + COMMA "," + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR "throw 42" + KW_THROW "throw" + INTEGER_LITERAL "42" + R_BRACE "}" + R_BRACE "}" + COMMA "," + R_BRACKET "]" + FOR_EXPR + KW_FOR "for" + L_PAREN "(" + LET_STMT "let task" + KW_LET "let" + WORD "task" + KW_IN "in" + WORD "tasks" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + PATH_EXPR "task.run" + WORD "task" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + WORD "null" + R_BRACE "}" FUNCTION_DEF KW_FUNCTION "function" WORD "test_run_mixed_tasks" @@ -190,6 +390,65 @@ SOURCE_FILE R_PAREN ")" R_PAREN ")" R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_run_tasks_alias" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_tasks_alias" + CALL_ARGS + L_PAREN "(" + CALL_EXPR + WORD "make_mixed_tasks" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_run_local_tasks" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_local_tasks" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "test_run_inline_tasks" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + WORD "run_inline_tasks" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" === ERRORS === None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index 8408609079..2ef8a7ec30 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -245,18 +245,36 @@ function user.make_int_handler() -> user.Handler [expr] { function user.make_string_handler() -> user.Handler [expr] { { } user.Handler { run: () -> null { { throw "error" } } } } +function user.maybe_use_string_handler(h: user.Handler?) -> null [expr] { + { h?.run() } null +} function user.test_int_handler() -> null [expr] { { } use_int_handler(make_int_handler()) } +function user.test_local_string_handler() -> null [expr] { + { } use_local_string_handler() +} +function user.test_optional_string_handler() -> null [expr] { + { } maybe_use_string_handler(make_string_handler()) +} function user.test_string_handler() -> null [expr] { { } use_string_handler(make_string_handler()) } +function user.test_string_handler_alias() -> null [expr] { + { } use_string_handler_alias(make_string_handler()) +} function user.use_int_handler(h: user.Handler) -> null [expr] { { } h.run() } +function user.use_local_string_handler() -> null [expr] { + { let h = make_string_handler() } h.run() +} function user.use_string_handler(h: user.Handler) -> null [expr] { { } h.run() } +function user.use_string_handler_alias(h: user.Handler) -> null [expr] { + { let alias = h } alias.run() +} class user.Box { value: user.T } @@ -273,12 +291,30 @@ class user.Task { function user.make_mixed_tasks() -> user.Task[] [expr] { { } [user.Task { name: "a", run: () -> null { { throw "error" } } }, user.Task { name: "b", run: () -> null { { throw 42 } } }] } +function user.run_inline_tasks() -> null [expr] { + { let tasks = [user.Task { name: "a", run: () -> null { { throw "error" } } }, user.Task { name: "b", run: () -> null { { throw 42 } } }]; for task in tasks { } task.run() } null +} +function user.run_local_tasks() -> null [expr] { + { let tasks = make_mixed_tasks(); for task in tasks { } task.run() } null +} function user.run_tasks(tasks: user.Task[]) -> null [expr] { { for task in tasks { } task.run() } null } +function user.run_tasks_alias(tasks: user.Task[]) -> null [expr] { + { let alias = tasks; for task in alias { } task.run() } null +} +function user.test_run_inline_tasks() -> null [expr] { + { } run_inline_tasks() +} +function user.test_run_local_tasks() -> null [expr] { + { } run_local_tasks() +} function user.test_run_mixed_tasks() -> null [expr] { { } run_tasks(make_mixed_tasks()) } +function user.test_run_tasks_alias() -> null [expr] { + { } run_tasks_alias(make_mixed_tasks()) +} function user.apply_guarded(f: () -> int) -> int [expr] { { let result = f(); if (result Lt 0) { throw "negative result" } } result } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index c436d1c1cd..a73c4bdc9e 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -840,7 +840,7 @@ fn .() -> null { fn user.make_enum_handler() -> EnumThrowingHandler { // Locals: let _0: EnumThrowingHandler // _0 // return - let _1: () -> void throws ErrorKind.NotFound + let _1: () -> void bb0: { _1 = make_closure lambda[0](); @@ -859,7 +859,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - throw const user.ErrorKind.NotFound; + throw const null; } } @@ -893,7 +893,7 @@ fn .() -> null { fn user.test_apply_enum_thrower() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void throws ErrorKind.Unauthorized + let _1: () -> void bb0: { _1 = make_closure lambda[0](); @@ -911,7 +911,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - throw const user.ErrorKind.Unauthorized; + throw const null; } } @@ -947,8 +947,8 @@ fn .() -> null { fn user.test_lambda_throws_enum() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void throws ErrorKind.ValidationFailed // f - let _2: () -> void throws ErrorKind.ValidationFailed + let _1: () -> void // f + let _2: () -> void bb0: { _1 = make_closure lambda[0](); @@ -967,7 +967,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - throw const user.ErrorKind.ValidationFailed; + throw const null; } } @@ -1001,7 +1001,7 @@ fn .() -> null { fn user.test_rethrows_enum() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void throws ErrorKind.NotFound + let _1: () -> void bb0: { _1 = make_closure lambda[0](); @@ -1019,7 +1019,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - throw const user.ErrorKind.NotFound; + throw const null; } } @@ -1053,7 +1053,7 @@ fn .() -> null { fn user.test_type_alias_enum() -> int { // Locals: let _0: int // _0 // return - let _1: () -> void throws ErrorKind.ValidationFailed + let _1: () -> void bb0: { _1 = make_closure lambda[0](); @@ -1071,7 +1071,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - throw const user.ErrorKind.ValidationFailed; + throw const null; } } @@ -1125,7 +1125,7 @@ fn user.throw_any_error(x: int) -> string { } bb4: { - throw const user.ErrorKind.NotFound; + throw const null; } bb5: { @@ -1193,7 +1193,7 @@ fn user.throw_enum_or_class(x: int) -> string { } bb5: { - throw const user.ErrorKind.RateLimited; + throw const null; } } @@ -1224,11 +1224,11 @@ fn user.throw_enum_variant(x: int) -> string { } bb4: { - throw const user.ErrorKind.Unauthorized; + throw const null; } bb5: { - throw const user.ErrorKind.NotFound; + throw const null; } } @@ -1290,15 +1290,15 @@ fn user.throw_various_errors(x: int) -> string { } bb3: { - throw const user.ErrorKind.ValidationFailed; + throw const null; } bb4: { - throw const user.ErrorKind.Unauthorized; + throw const null; } bb5: { - throw const user.ErrorKind.NotFound; + throw const null; } } @@ -1646,6 +1646,39 @@ fn .() -> null { } } +fn user.maybe_use_string_handler(h: Handler?) -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler? // h // param + let _2: void + let _3: (() -> null throws string)? + let _4: bool + + bb0: { + _4 = copy _1 == const null; + branch copy _4 -> [bb2, bb1]; + } + + bb1: { + _3 = copy _1.0; + _2 = call copy _3() -> [bb3]; + } + + bb2: { + _2 = const null; + goto -> bb3; + } + + bb3: { + _0 = const null; + goto -> bb4; + } + + bb4: { + return; + } +} + fn user.test_int_handler() -> null { // Locals: let _0: null // _0 // return @@ -1664,6 +1697,37 @@ fn user.test_int_handler() -> null { } } +fn user.test_local_string_handler() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = call const fn user.use_local_string_handler() -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.test_optional_string_handler() -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler + + bb0: { + _1 = call const fn user.make_string_handler() -> [bb1]; + } + + bb1: { + _0 = call const fn user.maybe_use_string_handler(copy _1) -> [bb2]; + } + + bb2: { + return; + } +} + fn user.test_string_handler() -> null { // Locals: let _0: null // _0 // return @@ -1682,6 +1746,24 @@ fn user.test_string_handler() -> null { } } +fn user.test_string_handler_alias() -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler + + bb0: { + _1 = call const fn user.make_string_handler() -> [bb1]; + } + + bb1: { + _0 = call const fn user.use_string_handler_alias(copy _1) -> [bb2]; + } + + bb2: { + return; + } +} + fn user.use_int_handler(h: Handler) -> null { // Locals: let _0: null // _0 // return @@ -1698,6 +1780,28 @@ fn user.use_int_handler(h: Handler) -> null { } } +fn user.use_local_string_handler() -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler // h + let _2: () -> null throws string + let _3: Handler + + bb0: { + _1 = call const fn user.make_string_handler() -> [bb1]; + } + + bb1: { + _3 = copy _1; + _2 = copy _3.0; + _0 = call copy _2() -> [bb2]; + } + + bb2: { + return; + } +} + fn user.use_string_handler(h: Handler) -> null { // Locals: let _0: null // _0 // return @@ -1714,6 +1818,26 @@ fn user.use_string_handler(h: Handler) -> null { } } +fn user.use_string_handler_alias(h: Handler) -> null { + // Locals: + let _0: null // _0 // return + let _1: Handler // h // param + let _2: Handler // alias + let _3: () -> null throws string + let _4: Handler + + bb0: { + _2 = copy _1; + _4 = copy _2; + _3 = copy _4.0; + _0 = call copy _3() -> [bb1]; + } + + bb1: { + return; + } +} + fn user.bad_box_mismatch() -> null { // Locals: let _0: null // _0 // return @@ -1786,6 +1910,137 @@ fn .() -> null { } } +fn user.run_inline_tasks() -> null { + // Locals: + let _0: null // _0 // return + let _1: Task[] // tasks + let _2: Task + let _3: () -> void throws string + let _4: Task + let _5: () -> void throws int + let _6: Task[] + let _7: int // __for_idx + let _8: int + let _9: bool + let _10: Task + let _11: Task // task + let _12: void + let _13: () -> null + let _14: Task + + bb0: { + _3 = make_closure lambda[0](); + _2 = Task { const "a", copy _3 }; + _5 = make_closure lambda[1](); + _4 = Task { const "b", copy _5 }; + _1 = [copy _2, copy _4]; + _6 = copy _1; + _7 = const 0_i64; + goto -> bb1; + } + + bb1: { + _8 = len(_6); + _9 = copy _7 < copy _8; + branch copy _9 -> [bb4, bb2]; + } + + bb2: { + _0 = const null; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + _10 = copy _6[_7]; + _11 = copy _10; + _14 = copy _11; + _13 = copy _14.1; + _12 = call copy _13() -> [bb5]; + } + + bb5: { + _7 = copy _7 + const 1_i64; + goto -> bb1; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "error"; + } +} + +// lambda[1] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const 42_i64; + } +} + +fn user.run_local_tasks() -> null { + // Locals: + let _0: null // _0 // return + let _1: Task[] // tasks + let _2: Task[] + let _3: int // __for_idx + let _4: int + let _5: bool + let _6: Task + let _7: Task // task + let _8: void + let _9: () -> null throws string | int + let _10: Task + + bb0: { + _1 = call const fn user.make_mixed_tasks() -> [bb1]; + } + + bb1: { + _2 = copy _1; + _3 = const 0_i64; + goto -> bb2; + } + + bb2: { + _4 = len(_2); + _5 = copy _3 < copy _4; + branch copy _5 -> [bb5, bb3]; + } + + bb3: { + _0 = const null; + goto -> bb4; + } + + bb4: { + return; + } + + bb5: { + _6 = copy _2[_3]; + _7 = copy _6; + _10 = copy _7; + _9 = copy _10.1; + _8 = call copy _9() -> [bb6]; + } + + bb6: { + _3 = copy _3 + const 1_i64; + goto -> bb2; + } +} + fn user.run_tasks(tasks: Task[]) -> null { // Locals: let _0: null // _0 // return @@ -1833,6 +2088,83 @@ fn user.run_tasks(tasks: Task[]) -> null { } } +fn user.run_tasks_alias(tasks: Task[]) -> null { + // Locals: + let _0: null // _0 // return + let _1: Task[] // tasks // param + let _2: Task[] // alias + let _3: Task[] + let _4: int // __for_idx + let _5: int + let _6: bool + let _7: Task + let _8: Task // task + let _9: void + let _10: () -> null throws string | int + let _11: Task + + bb0: { + _2 = copy _1; + _3 = copy _2; + _4 = const 0_i64; + goto -> bb1; + } + + bb1: { + _5 = len(_3); + _6 = copy _4 < copy _5; + branch copy _6 -> [bb4, bb2]; + } + + bb2: { + _0 = const null; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + _7 = copy _3[_4]; + _8 = copy _7; + _11 = copy _8; + _10 = copy _11.1; + _9 = call copy _10() -> [bb5]; + } + + bb5: { + _4 = copy _4 + const 1_i64; + goto -> bb1; + } +} + +fn user.test_run_inline_tasks() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = call const fn user.run_inline_tasks() -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.test_run_local_tasks() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = call const fn user.run_local_tasks() -> [bb1]; + } + + bb1: { + return; + } +} + fn user.test_run_mixed_tasks() -> null { // Locals: let _0: null // _0 // return @@ -1851,6 +2183,24 @@ fn user.test_run_mixed_tasks() -> null { } } +fn user.test_run_tasks_alias() -> null { + // Locals: + let _0: null // _0 // return + let _1: Task[] + + bb0: { + _1 = call const fn user.make_mixed_tasks() -> [bb1]; + } + + bb1: { + _0 = call const fn user.run_tasks_alias(copy _1) -> [bb2]; + } + + bb2: { + return; + } +} + fn user.apply_guarded(f: () -> int) -> int { // Locals: let _0: int // _0 // return diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index bb01c08a04..4c0db4fdaa 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -265,26 +265,32 @@ function user.throw_enum_variant(x: int) -> string throws user.ErrorKind { { : "ok" if (x == 0 : bool) : void { : never - throw ErrorKind.NotFound : user.ErrorKind.NotFound + throw ErrorKind.NotFound : unknown } if (x == 1 : bool) : void { : never - throw ErrorKind.Unauthorized : user.ErrorKind.Unauthorized + throw ErrorKind.Unauthorized : unknown } "ok" : "ok" } + !! 445..464: unresolved name: ErrorKind.NotFound + !! 488..511: unresolved name: ErrorKind.Unauthorized + !! 411..421: throws contract violation: `user.ErrorKind` is missing unknown + ?? 411..421: extraneous throws declaration: user.ErrorKind } -function user.test_lambda_throws_enum() -> int throws user.ErrorKind.ValidationFailed { +function user.test_lambda_throws_enum() -> int throws unknown { { : never - let f = : () -> never throws user.ErrorKind.ValidationFailed - () -> int { ... } : () -> never throws user.ErrorKind.ValidationFailed + let f = : () -> never throws unknown + () -> int { ... } : () -> never throws unknown { throw ErrorKind.ValidationFailed } f() : never } + !! 665..666: unresolved name: ErrorKind.ValidationFailed } lambda user.test_lambda_throws_enum { + !! 633..660: unresolved name: ErrorKind.ValidationFailed } function user.throw_class_instance(x: int) -> string throws user.ApiError { { : "ok" @@ -319,6 +325,11 @@ function user.throw_various_errors(x: int) -> string throws user.ErrorKind { _ => "ok" : "ok" } + !! 1245..1264: unresolved name: ErrorKind.NotFound + !! 1280..1303: unresolved name: ErrorKind.Unauthorized + !! 1319..1346: unresolved name: ErrorKind.ValidationFailed + !! 1204..1214: throws contract violation: `user.ErrorKind` is missing unknown + ?? 1204..1214: extraneous throws declaration: user.ErrorKind } function user.throw_mixed_classes(x: int) -> string throws user.ApiError | user.ValidationError { { : "ok" @@ -337,7 +348,7 @@ function user.throw_enum_or_class(x: int) -> string throws user.ErrorKind | user { : "ok" if (x == 0 : bool) : void { : never - throw ErrorKind.RateLimited : user.ErrorKind.RateLimited + throw ErrorKind.RateLimited : unknown } if (x == 1 : bool) : void { : never @@ -345,6 +356,9 @@ function user.throw_enum_or_class(x: int) -> string throws user.ErrorKind | user } "ok" : "ok" } + !! 1774..1796: unresolved name: ErrorKind.RateLimited + !! 1729..1750: throws contract violation: `user.ErrorKind | user.ApiError` is missing unknown + ?? 1729..1750: extraneous throws declaration: user.ErrorKind } function user.throw_any_error(x: int) -> string throws string | user.ErrorKind | user.ApiError { { : "ok" @@ -358,6 +372,9 @@ function user.throw_any_error(x: int) -> string throws string | user.ErrorKind | _ => "ok" : "ok" } + !! 2067..2086: unresolved name: ErrorKind.NotFound + !! 1968..1998: throws contract violation: `string | user.ErrorKind | user.ApiError` is missing unknown + ?? 1968..1998: extraneous throws declaration: user.ErrorKind } function user.apply_may_throw_enum(f: () -> int throws user.ErrorKind) -> int throws user.ErrorKind { { : int @@ -367,13 +384,15 @@ function user.apply_may_throw_enum(f: () -> int throws user.ErrorKind) -> int th function user.test_apply_enum_thrower() -> int throws user.ErrorKind { { : int apply_may_throw_enum(() -> int { ... }) : int - () -> int { ... } : () -> never throws user.ErrorKind.Unauthorized + () -> int { ... } : () -> never throws unknown { throw ErrorKind.Unauthorized } } + !! 2395..2437: unresolved name: ErrorKind.Unauthorized } lambda user.test_apply_enum_thrower { + !! 2412..2435: unresolved name: ErrorKind.Unauthorized } function user.apply_may_throw_class(f: () -> int throws user.ApiError) -> int throws user.ApiError { { : int @@ -396,16 +415,18 @@ function user.apply_generic_enum(f: () -> int throws __throws_f) -> int throws _ f() : int } } -function user.test_rethrows_enum() -> int throws user.ErrorKind.NotFound { +function user.test_rethrows_enum() -> int throws unknown { { : int apply_generic_enum(() -> int { ... }) : int - () -> int { ... } : () -> never throws user.ErrorKind.NotFound + () -> int { ... } : () -> never throws unknown { throw ErrorKind.NotFound } } + !! 2871..2909: unresolved name: ErrorKind.NotFound } lambda user.test_rethrows_enum { + !! 2888..2907: unresolved name: ErrorKind.NotFound } function user.apply_generic_class(f: () -> int throws __throws_f) -> int throws __throws_f { { : int @@ -468,8 +489,10 @@ function user.make_enum_handler() -> user.EnumThrowingHandler throws never { { : user.EnumThrowingHandler EnumThrowingHandler { run: () -> null { ... } } : user.EnumThrowingHandler } + !! 4056..4124: unresolved name: ErrorKind.NotFound } lambda user.make_enum_handler { + !! 4101..4120: unresolved name: ErrorKind.NotFound } function user.make_class_handler() -> user.ClassThrowingHandler throws never { { : user.ClassThrowingHandler @@ -499,13 +522,15 @@ function user.use_mixed_thrower(f: user.MixedThrower) -> int throws user.ApiErro function user.test_type_alias_enum() -> int throws user.ErrorKind { { : int use_enum_thrower(() -> int { ... }) : int - () -> int { ... } : () -> never throws user.ErrorKind.ValidationFailed + () -> int { ... } : () -> never throws unknown { throw ErrorKind.ValidationFailed } } + !! 4730..4776: unresolved name: ErrorKind.ValidationFailed } lambda user.test_type_alias_enum { + !! 4747..4774: unresolved name: ErrorKind.ValidationFailed } function user.test_type_alias_class() -> int throws user.ApiError { { : int @@ -674,6 +699,25 @@ function user.use_int_handler(h: user.Handler) -> null throws int { h.run() : null } } +function user.use_string_handler_alias(h: user.Handler) -> null throws string { + { : null + let alias = h : user.Handler + alias.run() : null + } +} +function user.use_local_string_handler() -> null throws string { + { : null + let h = make_string_handler() : user.Handler + h.run() : null + } +} +function user.maybe_use_string_handler(h: user.Handler?) -> null throws string { + { : null + h?.run() : unknown + null : null + } + !! 560..570: `(() -> null throws string)?` is not a function — it cannot be called +} function user.make_string_handler() -> user.Handler throws never { { : user.Handler Handler { run: () -> null { ... } } : user.Handler @@ -698,6 +742,21 @@ function user.test_int_handler() -> null throws int { use_int_handler(make_int_handler()) : null } } +function user.test_string_handler_alias() -> null throws string { + { : null + use_string_handler_alias(make_string_handler()) : null + } +} +function user.test_local_string_handler() -> null throws never { + { : null + use_local_string_handler() : null + } +} +function user.test_optional_string_handler() -> null throws never { + { : null + maybe_use_string_handler(make_string_handler()) : null + } +} class user.Handler$stream { run: unknown } @@ -731,6 +790,16 @@ function user.run_tasks(tasks: user.Task[]) -> null throws int | s null : null } } +function user.run_tasks_alias(tasks: user.Task[]) -> null throws int | string { + { : null + let alias = tasks : user.Task[] + for task in alias + { : null + task.run() : null + } + null : null + } +} function user.make_mixed_tasks() -> user.Task[] throws never { { : user.Task[] [Task { name: "a", run: () -> null { ... } }, Task { name: "b", run: () -> null { ... } }] : user.Task[] @@ -740,11 +809,50 @@ lambda user.make_mixed_tasks { } lambda user.make_mixed_tasks { } +function user.run_local_tasks() -> null throws int | string { + { : null + let tasks = make_mixed_tasks() : user.Task[] + for task in tasks + { : null + task.run() : null + } + null : null + } +} +function user.run_inline_tasks() -> null throws never { + { : null + let tasks = [Task { name: "a", run: () -> null { ... } }, Task { name: "b", run: () -> null { ... } }] : user.Task[] + for task in tasks + { : null + task.run() : null + } + null : null + } +} +lambda user.run_inline_tasks { +} +lambda user.run_inline_tasks { +} function user.test_run_mixed_tasks() -> null throws int | string { { : null run_tasks(make_mixed_tasks()) : null } } +function user.test_run_tasks_alias() -> null throws int | string { + { : null + run_tasks_alias(make_mixed_tasks()) : null + } +} +function user.test_run_local_tasks() -> null throws never { + { : null + run_local_tasks() : null + } +} +function user.test_run_inline_tasks() -> null throws never { + { : null + run_inline_tasks() : null + } +} class user.Task$stream { name: null | string run: unknown diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index 8c6f45604e..bb35df7c94 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -42,6 +42,256 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ────╯ + [type] Warning: extraneous throws declaration: user.ErrorKind + ╭─[ enum_class_throws.baml:27:53 ] + │ + 27 │ function throw_enum_variant(x: int) -> string throws ErrorKind { + │ ─────┬──── + │ ╰────── extraneous throws declaration: user.ErrorKind + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `user.ErrorKind` is missing unknown + ╭─[ enum_class_throws.baml:27:53 ] + │ + 27 │ function throw_enum_variant(x: int) -> string throws ErrorKind { + │ ─────┬──── + │ ╰────── throws contract violation: `user.ErrorKind` is missing unknown + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.NotFound + ╭─[ enum_class_throws.baml:28:22 ] + │ + 28 │ if (x == 0) { throw ErrorKind.NotFound } + │ ─────────┬───────── + │ ╰─────────── unresolved name: ErrorKind.NotFound + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.Unauthorized + ╭─[ enum_class_throws.baml:29:22 ] + │ + 29 │ if (x == 1) { throw ErrorKind.Unauthorized } + │ ───────────┬─────────── + │ ╰───────────── unresolved name: ErrorKind.Unauthorized + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.ValidationFailed + ╭─[ enum_class_throws.baml:35:28 ] + │ + 35 │ let f = () -> int { throw ErrorKind.ValidationFailed } + │ ─────────────┬───────────── + │ ╰─────────────── unresolved name: ErrorKind.ValidationFailed + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.ValidationFailed + ╭─[ enum_class_throws.baml:36:3 ] + │ + 36 │ f() + │ ┬ + │ ╰── unresolved name: ErrorKind.ValidationFailed + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: extraneous throws declaration: user.ErrorKind + ╭─[ enum_class_throws.baml:60:55 ] + │ + 60 │ function throw_various_errors(x: int) -> string throws ErrorKind { + │ ─────┬──── + │ ╰────── extraneous throws declaration: user.ErrorKind + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `user.ErrorKind` is missing unknown + ╭─[ enum_class_throws.baml:60:55 ] + │ + 60 │ function throw_various_errors(x: int) -> string throws ErrorKind { + │ ─────┬──── + │ ╰────── throws contract violation: `user.ErrorKind` is missing unknown + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.NotFound + ╭─[ enum_class_throws.baml:62:15 ] + │ + 62 │ 0 => throw ErrorKind.NotFound, + │ ─────────┬───────── + │ ╰─────────── unresolved name: ErrorKind.NotFound + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.Unauthorized + ╭─[ enum_class_throws.baml:63:15 ] + │ + 63 │ 1 => throw ErrorKind.Unauthorized, + │ ───────────┬─────────── + │ ╰───────────── unresolved name: ErrorKind.Unauthorized + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.ValidationFailed + ╭─[ enum_class_throws.baml:64:15 ] + │ + 64 │ 2 => throw ErrorKind.ValidationFailed, + │ ─────────────┬───────────── + │ ╰─────────────── unresolved name: ErrorKind.ValidationFailed + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: extraneous throws declaration: user.ErrorKind + ╭─[ enum_class_throws.baml:81:54 ] + │ + 81 │ function throw_enum_or_class(x: int) -> string throws ErrorKind | ApiError { + │ ──────────┬────────── + │ ╰──────────── extraneous throws declaration: user.ErrorKind + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `user.ErrorKind | user.ApiError` is missing unknown + ╭─[ enum_class_throws.baml:81:54 ] + │ + 81 │ function throw_enum_or_class(x: int) -> string throws ErrorKind | ApiError { + │ ──────────┬────────── + │ ╰──────────── throws contract violation: `user.ErrorKind | user.ApiError` is missing unknown + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.RateLimited + ╭─[ enum_class_throws.baml:82:22 ] + │ + 82 │ if (x == 0) { throw ErrorKind.RateLimited } + │ ───────────┬────────── + │ ╰──────────── unresolved name: ErrorKind.RateLimited + │ + │ Note: Error code: E0001 +────╯ + + [type] Warning: extraneous throws declaration: user.ErrorKind + ╭─[ enum_class_throws.baml:88:50 ] + │ + 88 │ function throw_any_error(x: int) -> string throws string | ErrorKind | ApiError { + │ ───────────────┬────────────── + │ ╰──────────────── extraneous throws declaration: user.ErrorKind + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: throws contract violation: `string | user.ErrorKind | user.ApiError` is missing unknown + ╭─[ enum_class_throws.baml:88:50 ] + │ + 88 │ function throw_any_error(x: int) -> string throws string | ErrorKind | ApiError { + │ ───────────────┬────────────── + │ ╰──────────────── throws contract violation: `string | user.ErrorKind | user.ApiError` is missing unknown + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.NotFound + ╭─[ enum_class_throws.baml:91:15 ] + │ + 91 │ 1 => throw ErrorKind.NotFound, + │ ─────────┬───────── + │ ╰─────────── unresolved name: ErrorKind.NotFound + │ + │ Note: Error code: E0001 +────╯ + + [type] Error: unresolved name: ErrorKind.Unauthorized + ╭─[ enum_class_throws.baml:105:24 ] + │ + 105 │ apply_may_throw_enum(() -> int { throw ErrorKind.Unauthorized }) + │ ─────────────────────┬──────────────────── + │ ╰────────────────────── unresolved name: ErrorKind.Unauthorized + │ + │ Note: Error code: E0001 +─────╯ + + [type] Error: unresolved name: ErrorKind.Unauthorized + ╭─[ enum_class_throws.baml:105:41 ] + │ + 105 │ apply_may_throw_enum(() -> int { throw ErrorKind.Unauthorized }) + │ ───────────┬─────────── + │ ╰───────────── unresolved name: ErrorKind.Unauthorized + │ + │ Note: Error code: E0001 +─────╯ + + [type] Error: unresolved name: ErrorKind.NotFound + ╭─[ enum_class_throws.baml:123:22 ] + │ + 123 │ apply_generic_enum(() -> int { throw ErrorKind.NotFound }) + │ ───────────────────┬────────────────── + │ ╰──────────────────── unresolved name: ErrorKind.NotFound + │ + │ Note: Error code: E0001 +─────╯ + + [type] Error: unresolved name: ErrorKind.NotFound + ╭─[ enum_class_throws.baml:123:39 ] + │ + 123 │ apply_generic_enum(() -> int { throw ErrorKind.NotFound }) + │ ─────────┬───────── + │ ╰─────────── unresolved name: ErrorKind.NotFound + │ + │ Note: Error code: E0001 +─────╯ + + [type] Error: unresolved name: ErrorKind.NotFound + ╭─[ enum_class_throws.baml:174:3 ] + │ + 174 │ EnumThrowingHandler { run: () -> null { throw ErrorKind.NotFound } } + │ ──────────────────────────────────┬───────────────────────────────── + │ ╰─────────────────────────────────── unresolved name: ErrorKind.NotFound + │ + │ Note: Error code: E0001 +─────╯ + + [type] Error: unresolved name: ErrorKind.NotFound + ╭─[ enum_class_throws.baml:174:48 ] + │ + 174 │ EnumThrowingHandler { run: () -> null { throw ErrorKind.NotFound } } + │ ─────────┬───────── + │ ╰─────────── unresolved name: ErrorKind.NotFound + │ + │ Note: Error code: E0001 +─────╯ + + [type] Error: unresolved name: ErrorKind.ValidationFailed + ╭─[ enum_class_throws.baml:194:20 ] + │ + 194 │ use_enum_thrower(() -> int { throw ErrorKind.ValidationFailed }) + │ ───────────────────────┬────────────────────── + │ ╰──────────────────────── unresolved name: ErrorKind.ValidationFailed + │ + │ Note: Error code: E0001 +─────╯ + + [type] Error: unresolved name: ErrorKind.ValidationFailed + ╭─[ enum_class_throws.baml:194:37 ] + │ + 194 │ use_enum_thrower(() -> int { throw ErrorKind.ValidationFailed }) + │ ─────────────┬───────────── + │ ╰─────────────── unresolved name: ErrorKind.ValidationFailed + │ + │ Note: Error code: E0001 +─────╯ + [validation] Error: Name `Wrapper` defined 2 times as: type, class ╭─[ fn_type_alias_throws.baml:20:6 ] │ @@ -79,6 +329,16 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ───╯ + [type] Error: `(() -> null throws string)?` is not a function — it cannot be called + ╭─[ generic_stored_callback.baml:27:1 ] + │ + 27 │ h?.run() + │ ─────┬──── + │ ╰────── `(() -> null throws string)?` is not a function — it cannot be called + │ + │ Note: Error code: E0001 +────╯ + [type] Error: type mismatch: expected user.Box, got user.Box ╭─[ generic_subtype_rejection.baml:14:19 ] │ diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 298311c96d..9769e235e1 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -437,6 +437,24 @@ function user.map_it(x: int, f: (int) -> string) -> string { return } +function user.maybe_use_string_handler(h: Handler?) -> null { + load_var h + load_const null + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var h + load_field .0 + call_indirect + pop 1 + + L1: + load_const null + return +} + function user.optional_call_caught(cb: (() -> int throws string)?) -> int? { load_var cb load_const null @@ -479,6 +497,80 @@ function user.optional_call_rethrows(cb: (() -> int throws string)?) -> int? { return } +function user.run_inline_tasks() -> null { + alloc_instance Task + load_const "a" + init_field .name + make_closure ., 0 + init_field .run + alloc_instance Task + load_const "b" + init_field .name + make_closure ., 0 + init_field .run + alloc_array 2 + store_var _6 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _6 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_const null + return + + L2: + load_var _6 + load_var __for_idx + load_array_element + load_field .run + call_indirect + pop 1 + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 +} + +function user.run_local_tasks() -> null { + call user.make_mixed_tasks + store_var _2 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _2 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_const null + return + + L2: + load_var _2 + load_var __for_idx + load_array_element + load_field .run + call_indirect + pop 1 + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 +} + function user.run_pure(f: () -> int) -> int { load_var f call_indirect @@ -515,6 +607,38 @@ function user.run_tasks(tasks: Task[]) -> null { jump L0 } +function user.run_tasks_alias(tasks: Task[]) -> null { + load_var tasks + store_var _3 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _3 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_const null + return + + L2: + load_var _3 + load_var __for_idx + load_array_element + load_field .run + call_indirect + pop 1 + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 +} + function user.run_throwing(f: () -> int) -> int { load_var f call_indirect @@ -797,6 +921,11 @@ function user.test_lambda_throws_enum() -> int { return } +function user.test_local_string_handler() -> null { + call user.use_local_string_handler + return +} + function user.test_many_args_pure() -> string { load_const 1 load_const "hello" @@ -879,6 +1008,12 @@ function user.test_optional_call_with_throwing_callback() -> int? { return } +function user.test_optional_string_handler() -> null { + call user.make_string_handler + call user.maybe_use_string_handler + return +} + function user.test_pure_lambda() -> int { make_closure ., 0 call_indirect @@ -903,6 +1038,16 @@ function user.test_rethrows_enum() -> int { return } +function user.test_run_inline_tasks() -> null { + call user.run_inline_tasks + return +} + +function user.test_run_local_tasks() -> null { + call user.run_local_tasks + return +} + function user.test_run_mixed_tasks() -> null { call user.make_mixed_tasks call user.run_tasks @@ -915,6 +1060,12 @@ function user.test_run_pure() -> int { return } +function user.test_run_tasks_alias() -> null { + call user.make_mixed_tasks + call user.run_tasks_alias + return +} + function user.test_run_throwing() -> int { make_closure ., 0 call user.run_throwing @@ -963,6 +1114,12 @@ function user.test_string_handler() -> null { return } +function user.test_string_handler_alias() -> null { + call user.make_string_handler + call user.use_string_handler_alias + return +} + function user.test_throwing_int() -> string { make_closure ., 0 call_indirect @@ -1099,8 +1256,7 @@ function user.throw_any_error(x: int) -> string { throw L4: - load_const user.ErrorKind.NotFound - alloc_variant user.ErrorKind + load_const null throw L5: @@ -1155,8 +1311,7 @@ function user.throw_enum_or_class(x: int) -> string { throw L3: - load_const user.ErrorKind.RateLimited - alloc_variant user.ErrorKind + load_const null throw } @@ -1179,13 +1334,11 @@ function user.throw_enum_variant(x: int) -> string { return L2: - load_const user.ErrorKind.Unauthorized - alloc_variant user.ErrorKind + load_const null throw L3: - load_const user.ErrorKind.NotFound - alloc_variant user.ErrorKind + load_const null throw } @@ -1255,18 +1408,15 @@ function user.throw_various_errors(x: int) -> string { return L3: - load_const user.ErrorKind.ValidationFailed - alloc_variant user.ErrorKind + load_const null throw L4: - load_const user.ErrorKind.Unauthorized - alloc_variant user.ErrorKind + load_const null throw L5: - load_const user.ErrorKind.NotFound - alloc_variant user.ErrorKind + load_const null throw } @@ -1302,6 +1452,13 @@ function user.use_int_handler(h: Handler) -> null { return } +function user.use_local_string_handler() -> null { + call user.make_string_handler + load_field .run + call_indirect + return +} + function user.use_mixed_thrower(f: () -> int throws ErrorKind | ApiError | string) -> int { load_var f call_indirect @@ -1314,3 +1471,10 @@ function user.use_string_handler(h: Handler) -> null { call_indirect return } + +function user.use_string_handler_alias(h: Handler) -> null { + load_var h + load_field .run + call_indirect + return +} From 7be9e5745a1a47a57952f48d9c9345caea04ad86 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Fri, 10 Apr 2026 03:04:12 -0500 Subject: [PATCH 24/26] Refine package resolution context and throws recovery --- .../crates/baml_compiler2_tir/src/builder.rs | 45 +- .../baml_compiler2_tir/src/inference.rs | 4 +- .../src/package_interface.rs | 489 ++++++++++-------- .../crates/baml_compiler2_tir/src/resolve.rs | 6 +- .../baml_compiler2_tir/src/throw_inference.rs | 184 ++++--- 5 files changed, 375 insertions(+), 353 deletions(-) diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 4857289498..0ed568e303 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -2825,7 +2825,7 @@ impl<'db> TypeInferenceBuilder<'db> { ); // Don't report "unresolved name" for dependency package names — // they'll be resolved by the parent FieldAccess expression. - let is_dep_package = self.res_ctx.dep_interfaces.iter().any(|(n, _)| n == name); + let is_dep_package = self.res_ctx.deps.iter().any(|dep| dep.name == *name); if matches!(ty, Ty::Unknown { .. }) && !self.locals.contains_key(name) && self @@ -2985,28 +2985,13 @@ impl<'db> TypeInferenceBuilder<'db> { full_path.extend_from_slice(path); if let Some((_source, function)) = self.res_ctx.lookup_function(db, &full_path, &[]) { - let item = path.last().expect("non-empty path"); - if let Some(Definition::Function(func_loc)) = self - .res_ctx - .lookup_value_definition_in_package(db, pkg_name, &path[..path.len() - 1], item) - { - self.resolutions.insert( - expr_id, - crate::inference::MemberResolution::Free { func_loc }, - ); - } - return Some(Ty::Function { - params: function - .params - .into_iter() - .map(|(n, ty)| (Some(n), ty)) - .collect(), - ret: Box::new(function.return_type), - throws: Box::new(function.outward_throws.unwrap_or(Ty::Never { - attr: TyAttr::default(), - })), - attr: TyAttr::default(), - }); + self.resolutions.insert( + expr_id, + crate::inference::MemberResolution::Free { + func_loc: function.func_loc, + }, + ); + return Some(function.as_ty()); } self.res_ctx @@ -3523,19 +3508,7 @@ impl<'db> TypeInferenceBuilder<'db> { let resolved = self .res_ctx .lookup_class_method(db, class_name, method_name)?; - let ty = Ty::Function { - params: resolved - .function - .params - .into_iter() - .map(|(n, ty)| (Some(n), ty)) - .collect(), - ret: Box::new(resolved.function.return_type), - throws: Box::new(resolved.function.outward_throws.unwrap_or(Ty::Never { - attr: TyAttr::default(), - })), - attr: TyAttr::default(), - }; + let ty = resolved.function.as_ty(); Some((ty, class_loc, func_loc)) } } diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 0facee0462..8cf9f4785e 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -291,8 +291,8 @@ pub fn infer_scope_types<'db>( let mut aliases = collect_type_aliases(db, pkg_items); // Also collect type aliases from dependency packages so that e.g. // `testing.TestRunner` can be resolved during subtype checking. - for (_dep_name, dep_iface) in &res_ctx.dep_interfaces { - for types_in_ns in dep_iface.types.values() { + for dep in &res_ctx.deps { + for types_in_ns in dep.interface.types.values() { for exported in types_in_ns.values() { if let crate::package_interface::ExportedType::TypeAlias { qtn, resolved } = exported diff --git a/baml_language/crates/baml_compiler2_tir/src/package_interface.rs b/baml_language/crates/baml_compiler2_tir/src/package_interface.rs index defed4c390..0de7b2171b 100644 --- a/baml_language/crates/baml_compiler2_tir/src/package_interface.rs +++ b/baml_language/crates/baml_compiler2_tir/src/package_interface.rs @@ -1,7 +1,7 @@ //! Package interface types and resolution context. //! //! `PackageInterface` is a fully-resolved typed summary of everything a package -//! exports — classes, enums, type aliases, functions, and throw sets. +//! exports — classes, enums, type aliases, functions, and callable identities. //! Dependent packages consume this instead of reaching into raw `ItemTree` / //! `TypeExpr` data. //! @@ -80,7 +80,8 @@ pub enum ResolvedSource { } /// Common output for resolved function signatures. -pub struct ResolvedFunction { +pub struct ResolvedFunction<'db> { + pub func_loc: baml_compiler2_hir::loc::FunctionLoc<'db>, pub name: Name, pub params: Vec<(Name, Ty)>, pub return_type: Ty, @@ -90,12 +91,29 @@ pub struct ResolvedFunction { } /// Common output for resolved method lookups (includes class context). -pub struct ResolvedMethod { - pub function: ResolvedFunction, +pub struct ResolvedMethod<'db> { + pub function: ResolvedFunction<'db>, pub class_name: Name, pub class_generic_params: Vec, } +impl ResolvedFunction<'_> { + pub fn as_ty(&self) -> Ty { + Ty::Function { + params: self + .params + .iter() + .map(|(name, ty)| (Some(name.clone()), ty.clone())) + .collect(), + ret: Box::new(self.return_type.clone()), + throws: Box::new(self.outward_throws.clone().unwrap_or(Ty::Never { + attr: TyAttr::default(), + })), + attr: TyAttr::default(), + } + } +} + /// Specialized resolution result for builtin class members. /// /// Builtin methods can be effect-polymorphic over direct callback parameters, so @@ -110,12 +128,17 @@ pub enum ResolvedBuiltinMember<'db> { Field(Ty), } +pub struct ResolvedDependency<'db> { + pub name: Name, + pub package_id: PackageId<'db>, + pub interface: &'db PackageInterface, +} + /// Bundles a package's own items with its dependencies' pre-resolved interfaces. /// All cross-package lookups go through this context's methods. pub struct PackageResolutionContext<'db> { pub own_items: &'db PackageItems<'db>, - pub dep_interfaces: Vec<(Name, &'db PackageInterface)>, - pub dep_packages: Vec<(Name, PackageId<'db>)>, + pub deps: Vec>, pub own_package_name: Name, } @@ -123,13 +146,12 @@ impl PartialEq for PackageResolutionContext<'_> { fn eq(&self, other: &Self) -> bool { std::ptr::eq(self.own_items, other.own_items) && self.own_package_name == other.own_package_name - && self.dep_interfaces.len() == other.dep_interfaces.len() - && self.dep_packages == other.dep_packages - && self - .dep_interfaces - .iter() - .zip(other.dep_interfaces.iter()) - .all(|((n1, i1), (n2, i2))| n1 == n2 && std::ptr::eq(*i1, *i2)) + && self.deps.len() == other.deps.len() + && self.deps.iter().zip(other.deps.iter()).all(|(d1, d2)| { + d1.name == d2.name + && d1.package_id == d2.package_id + && std::ptr::eq(d1.interface, d2.interface) + }) } } @@ -498,6 +520,150 @@ fn outward_throws_from_key(throw_sets: &FunctionThrowSets, key: &Name) -> Option } } +fn builtin_runtime_class_ty( + package: &Name, + namespace_path: &[Name], + class_name: &Name, + class_generic_params: &[Name], + type_args: &[Ty], +) -> Ty { + if type_args.is_empty() { + Ty::Class( + QualifiedTypeName::new_with_generic_params( + package.clone(), + namespace_path.to_vec(), + class_name.clone(), + class_generic_params.to_vec(), + ) + .into(), + TyAttr::default(), + ) + } else if type_args.len() == 1 && class_name.as_str() == "Array" { + Ty::List(Box::new(type_args[0].clone()), TyAttr::default()) + } else if type_args.len() == 2 && class_name.as_str() == "Map" { + Ty::Map( + Box::new(type_args[0].clone()), + Box::new(type_args[1].clone()), + TyAttr::default(), + ) + } else { + Ty::Class( + QualifiedTypeName::new(package.clone(), namespace_path.to_vec(), class_name.clone()) + .into(), + TyAttr::default(), + ) + } +} + +fn synthetic_effect_throws_ty(effect_vars: &[Name]) -> Ty { + match effect_vars.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => Ty::TypeVar(effect_vars[0].clone(), TyAttr::default()), + _ => Ty::Union( + effect_vars + .iter() + .map(|name| Ty::TypeVar(name.clone(), TyAttr::default())) + .collect(), + TyAttr::default(), + ), + } +} + +#[allow(clippy::too_many_arguments)] +fn lower_builtin_method_ty<'db>( + db: &'db dyn crate::Db, + baml_items: &'db PackageItems<'db>, + stub_ns: &[Name], + package: &Name, + namespace_path: &[Name], + class_name: &Name, + class_generic_params: &[Name], + method_generic_params: &[Name], + sig: &baml_compiler2_hir::signature::FunctionSignature, + type_args: &[Ty], + bindings: &FxHashMap, +) -> Ty { + let builtin_class_ty = builtin_runtime_class_ty( + package, + namespace_path, + class_name, + class_generic_params, + type_args, + ); + let mut diags = Vec::new(); + let mut synthetic_effect_vars: Vec = Vec::new(); + let params: Vec<(Option, Ty)> = sig + .params + .iter() + .map(|(n, te)| { + let ty = if n.as_str() == "self" + && matches!(te, baml_compiler2_ast::TypeExpr::Unknown { .. }) + { + builtin_class_ty.clone() + } else if matches!(te, baml_compiler2_ast::TypeExpr::Function { .. }) { + let lowered = crate::lower_type_expr::lower_type_expr_with_fn_context( + db, + te, + baml_items, + stub_ns, + method_generic_params, + &mut diags, + &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { + param_name: n.clone(), + }, + &mut synthetic_effect_vars, + ); + crate::generics::substitute_ty(&lowered, bindings) + } else { + crate::generics::lower_type_expr_with_generics( + db, te, baml_items, stub_ns, bindings, &mut diags, + ) + }; + (Some(n.clone()), ty) + }) + .collect(); + + let ret = sig + .return_type + .as_ref() + .map(|te| { + crate::generics::lower_type_expr_with_generics( + db, te, baml_items, stub_ns, bindings, &mut diags, + ) + }) + .unwrap_or(Ty::Void { + attr: TyAttr::default(), + }); + + Ty::Function { + params, + ret: Box::new(ret), + throws: Box::new(synthetic_effect_throws_ty(&synthetic_effect_vars)), + attr: TyAttr::default(), + } +} + +fn lower_builtin_field_ty<'db>( + db: &'db dyn crate::Db, + baml_items: &'db PackageItems<'db>, + stub_ns: &[Name], + type_expr: Option<&baml_compiler2_ast::SpannedTypeExpr>, + bindings: &FxHashMap, +) -> Ty { + let mut diags = Vec::new(); + type_expr + .map(|te| { + crate::generics::lower_type_expr_with_generics( + db, &te.expr, baml_items, stub_ns, bindings, &mut diags, + ) + }) + .unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }) +} + // ── package_resolution_context Salsa query ───────────────────────────────── #[salsa::tracked(returns(ref))] @@ -507,22 +673,17 @@ pub fn package_resolution_context<'db>( ) -> PackageResolutionContext<'db> { let own_items = package_items(db, pkg_id); let deps = package_dependencies(db, pkg_id); - let dep_packages: Vec<(Name, PackageId<'db>)> = deps + let deps: Vec> = deps .iter() - .map(|dep_id| (dep_id.name(db), *dep_id)) - .collect(); - let dep_interfaces: Vec<(Name, &PackageInterface)> = deps - .iter() - .map(|dep_id| { - let name = dep_id.name(db); - let iface = package_interface(db, *dep_id); - (name, iface) + .map(|dep_id| ResolvedDependency { + name: dep_id.name(db), + package_id: *dep_id, + interface: package_interface(db, *dep_id), }) .collect(); PackageResolutionContext { own_items, - dep_interfaces, - dep_packages, + deps, own_package_name: pkg_id.name(db), } } @@ -530,12 +691,6 @@ pub fn package_resolution_context<'db>( // ── PackageResolutionContext lookup methods ───────────────────────────────── impl<'db> PackageResolutionContext<'db> { - fn dep_package_id(&self, pkg_name: &Name) -> Option> { - self.dep_packages - .iter() - .find_map(|(dep_name, dep_id)| (dep_name == pkg_name).then_some(*dep_id)) - } - /// Get `PackageItems` for an accessible package (own or declared dependency). /// /// Returns `Some` for the own package and any declared dependency, @@ -547,13 +702,12 @@ impl<'db> PackageResolutionContext<'db> { ) -> Option<&'db PackageItems<'db>> { if pkg_name.as_str() == self.own_package_name.as_str() { Some(self.own_items) - } else if self - .dep_interfaces + } else if let Some(dep) = self + .deps .iter() - .any(|(n, _)| n.as_str() == pkg_name.as_str()) + .find(|dep| dep.name.as_str() == pkg_name.as_str()) { - let pkg_id = PackageId::new(db, pkg_name.clone()); - Some(package_items(db, pkg_id)) + Some(package_items(db, dep.package_id)) } else { None } @@ -564,7 +718,7 @@ impl<'db> PackageResolutionContext<'db> { db: &'db dyn crate::Db, path: &[Name], ns_context: &[Name], - ) -> Option<(ResolvedSource, ResolvedFunction)> { + ) -> Option<(ResolvedSource, ResolvedFunction<'db>)> { let item = path.last()?; if !ns_context.is_empty() { let ns: Vec<_> = ns_context @@ -590,13 +744,12 @@ impl<'db> PackageResolutionContext<'db> { return Some((ResolvedSource::Item, function)); } } - for (dep_name, _dep_iface) in &self.dep_interfaces { - if &path[0] == dep_name { - if let Some(function) = - self.lookup_dep_function(db, dep_name, &path[1..path.len() - 1], item) - { - return Some((ResolvedSource::Builtin, function)); - } + for dep in &self.deps { + if path[0] == dep.name + && let Some(function) = + Self::lookup_dep_function(db, dep, &path[1..path.len() - 1], item) + { + return Some((ResolvedSource::Builtin, function)); } } } @@ -608,12 +761,12 @@ impl<'db> PackageResolutionContext<'db> { db: &'db dyn crate::Db, namespace: &[Name], item: &Name, - ) -> Option<(ResolvedSource, ResolvedFunction)> { + ) -> Option<(ResolvedSource, ResolvedFunction<'db>)> { if let Some(function) = self.lookup_own_function(db, namespace, item) { return Some((ResolvedSource::Item, function)); } - for (dep_name, _dep_iface) in &self.dep_interfaces { - if let Some(function) = self.lookup_dep_function(db, dep_name, namespace, item) { + for dep in &self.deps { + if let Some(function) = Self::lookup_dep_function(db, dep, namespace, item) { return Some((ResolvedSource::Builtin, function)); } } @@ -625,7 +778,7 @@ impl<'db> PackageResolutionContext<'db> { db: &'db dyn crate::Db, namespace: &[Name], item: &Name, - ) -> Option { + ) -> Option> { let Definition::Function(func_loc) = self.own_items.lookup_value(namespace, item)? else { return None; }; @@ -678,6 +831,7 @@ impl<'db> PackageResolutionContext<'db> { }; drop(diags); Some(ResolvedFunction { + func_loc, name: func_data.name.clone(), params, return_type, @@ -688,24 +842,27 @@ impl<'db> PackageResolutionContext<'db> { } fn lookup_dep_function( - &self, db: &'db dyn crate::Db, - pkg_name: &Name, + dep: &ResolvedDependency<'db>, namespace: &[Name], item: &Name, - ) -> Option { - let dep_iface = self - .dep_interfaces - .iter() - .find_map(|(dep_name, dep_iface)| (dep_name == pkg_name).then_some(*dep_iface))?; - let exported = dep_iface.lookup_function(namespace, item)?; - let dep_pkg_id = self.dep_package_id(pkg_name)?; + ) -> Option> { + let exported = dep.interface.lookup_function(namespace, item)?; + let Definition::Function(func_loc) = + package_items(db, dep.package_id).lookup_value(namespace, item)? + else { + return None; + }; Some(ResolvedFunction { + func_loc, name: exported.name.clone(), params: exported.params.clone(), return_type: exported.return_type.clone(), outward_throws: exported.declared_throws.clone().or_else(|| { - outward_throws_from_key(function_throw_sets(db, dep_pkg_id), &exported.callable_key) + outward_throws_from_key( + function_throw_sets(db, dep.package_id), + &exported.callable_key, + ) }), generic_params: exported.generic_params.clone(), builtin_kind: exported.builtin_kind, @@ -750,9 +907,11 @@ impl<'db> PackageResolutionContext<'db> { return Some((ResolvedSource::Item, ty)); } } - for (dep_name, dep_iface) in &self.dep_interfaces { - if &path[0] == dep_name { - if let Some(exported) = dep_iface.lookup_type(&path[1..path.len() - 1], item) { + for dep in &self.deps { + if path[0] == dep.name { + if let Some(exported) = + dep.interface.lookup_type(&path[1..path.len() - 1], item) + { return Some((ResolvedSource::Builtin, exported.to_ty())); } } @@ -772,8 +931,8 @@ impl<'db> PackageResolutionContext<'db> { let ty = def_to_ty(db, def); return Some((ResolvedSource::Item, ty)); } - for (_dep_name, dep_iface) in &self.dep_interfaces { - if let Some(exported) = dep_iface.lookup_type(namespace, item) { + for dep in &self.deps { + if let Some(exported) = dep.interface.lookup_type(namespace, item) { return Some((ResolvedSource::Builtin, exported.to_ty())); } } @@ -813,11 +972,11 @@ impl<'db> PackageResolutionContext<'db> { } } // dep-prefixed search (parity with resolve_type) - for (dep_name, _dep_iface) in &self.dep_interfaces { - if &path[0] == dep_name { + for dep in &self.deps { + if path[0] == dep.name { if let Some(def) = self.lookup_value_definition_in_package( db, - dep_name, + &dep.name, &path[1..path.len() - 1], item, ) { @@ -912,13 +1071,16 @@ impl<'db> PackageResolutionContext<'db> { let enum_data = &item_tree[enum_loc.id(db)]; enum_data.variants.iter().map(|v| v.name.clone()).collect() } else { - self.dep_interfaces + self.deps .iter() - .find_map(|(dep_name, dep_iface)| { - if dep_name != enum_name.package() { + .find_map(|dep| { + if &dep.name != enum_name.package() { return None; } - match dep_iface.lookup_type(enum_name.namespace(), enum_name.name())? { + match dep + .interface + .lookup_type(enum_name.namespace(), enum_name.name())? + { ExportedType::Enum { variants, .. } => Some(variants.clone()), _ => None, } @@ -940,15 +1102,17 @@ impl<'db> PackageResolutionContext<'db> { let raw_fields = self.lookup_own_class_fields(db, class_name.qtn()); self.apply_class_substitution(db, class_name, raw_fields) } else { - for (dep_name, dep_iface) in &self.dep_interfaces { - if dep_name != class_pkg { + for dep in &self.deps { + if &dep.name != class_pkg { continue; } if let Some(ExportedType::Class { fields, generic_params, .. - }) = dep_iface.lookup_type(class_name.namespace(), class_name.name()) + }) = dep + .interface + .lookup_type(class_name.namespace(), class_name.name()) { if class_name.type_args().is_empty() || generic_params.is_empty() { return fields.clone(); @@ -1002,7 +1166,7 @@ impl<'db> PackageResolutionContext<'db> { db: &'db dyn crate::Db, class_name: &NominalTypeRef, method_name: &Name, - ) -> Option { + ) -> Option> { let class_pkg = class_name.package(); if class_pkg.as_str() == self.own_package_name.as_str() { self.lookup_own_class_method(db, class_name.qtn(), method_name) @@ -1011,28 +1175,33 @@ impl<'db> PackageResolutionContext<'db> { rm }) } else { - for (dep_name, dep_iface) in &self.dep_interfaces { - if dep_name != class_pkg { + for dep in &self.deps { + if &dep.name != class_pkg { continue; } - let Some(dep_pkg_id) = self.dep_package_id(dep_name) else { - continue; - }; if let Some(ExportedType::Class { methods, generic_params, .. - }) = dep_iface.lookup_type(class_name.namespace(), class_name.name()) + }) = dep + .interface + .lookup_type(class_name.namespace(), class_name.name()) { if let Some(method) = methods.iter().find(|m| &m.name == method_name) { + let Some((_, func_loc)) = + self.lookup_class_method_locs(db, class_name, method_name) + else { + continue; + }; let mut resolved_method = ResolvedMethod { function: ResolvedFunction { + func_loc, name: method.name.clone(), params: method.params.clone(), return_type: method.return_type.clone(), outward_throws: method.declared_throws.clone().or_else(|| { outward_throws_from_key( - function_throw_sets(db, dep_pkg_id), + function_throw_sets(db, dep.package_id), &method.callable_key, ) }), @@ -1142,134 +1311,26 @@ impl<'db> PackageResolutionContext<'db> { let func_loc = baml_compiler2_hir::loc::FunctionLoc::new(db, file, method_id); let sig = baml_compiler2_hir::signature::function_signature(db, func_loc); - let mut diags = Vec::new(); - - let builtin_class_ty = if type_args.is_empty() { - Ty::Class( - QualifiedTypeName::new_with_generic_params( - stub_pkg.package.clone(), - stub_pkg.namespace_path.clone(), - class_data.name.clone(), - class_data.generic_params.clone(), - ) - .into(), - TyAttr::default(), - ) - } else if type_args.len() == 1 { - match class_data.name.as_str() { - "Array" => Ty::List(Box::new(type_args[0].clone()), TyAttr::default()), - _ => Ty::Class( - QualifiedTypeName::new( - stub_pkg.package.clone(), - stub_pkg.namespace_path.clone(), - class_data.name.clone(), - ) - .into(), - TyAttr::default(), - ), - } - } else if type_args.len() == 2 { - match class_data.name.as_str() { - "Map" => Ty::Map( - Box::new(type_args[0].clone()), - Box::new(type_args[1].clone()), - TyAttr::default(), - ), - _ => Ty::Class( - QualifiedTypeName::new( - stub_pkg.package.clone(), - stub_pkg.namespace_path.clone(), - class_data.name.clone(), - ) - .into(), - TyAttr::default(), - ), - } - } else { - Ty::Class( - QualifiedTypeName::new( - stub_pkg.package.clone(), - stub_pkg.namespace_path.clone(), - class_data.name.clone(), - ) - .into(), - TyAttr::default(), - ) - }; - let method_generic_params: Vec = class_data .generic_params .iter() .chain(method_data.generic_params.iter()) .cloned() .collect(); - - let mut synthetic_effect_vars: Vec = Vec::new(); - let params: Vec<(Option, Ty)> = sig - .params - .iter() - .map(|(n, te)| { - let ty = if n.as_str() == "self" - && matches!(te, baml_compiler2_ast::TypeExpr::Unknown { .. }) - { - builtin_class_ty.clone() - } else if matches!(te, baml_compiler2_ast::TypeExpr::Function { .. }) { - let lowered = crate::lower_type_expr::lower_type_expr_with_fn_context( - db, - te, - baml_items, - stub_ns, - &method_generic_params, - &mut diags, - &crate::lower_type_expr::FnTypeLoweringContext::DirectParamRoot { - param_name: n.clone(), - }, - &mut synthetic_effect_vars, - ); - crate::generics::substitute_ty(&lowered, &bindings) - } else { - crate::generics::lower_type_expr_with_generics( - db, te, baml_items, stub_ns, &bindings, &mut diags, - ) - }; - (Some(n.clone()), ty) - }) - .collect(); - - let ret = sig - .return_type - .as_ref() - .map(|te| { - crate::generics::lower_type_expr_with_generics( - db, te, baml_items, stub_ns, &bindings, &mut diags, - ) - }) - .unwrap_or(Ty::Void { - attr: TyAttr::default(), - }); - - let throws_ty = match synthetic_effect_vars.len() { - 0 => Ty::Never { - attr: TyAttr::default(), - }, - 1 => Ty::TypeVar(synthetic_effect_vars[0].clone(), TyAttr::default()), - _ => Ty::Union( - synthetic_effect_vars - .iter() - .map(|v| Ty::TypeVar(v.clone(), TyAttr::default())) - .collect(), - TyAttr::default(), - ), - }; - - drop(diags); return Some(ResolvedBuiltinMember::Method { - ty: Ty::Function { - params, - ret: Box::new(ret), - throws: Box::new(throws_ty), - attr: TyAttr::default(), - }, + ty: lower_builtin_method_ty( + db, + baml_items, + stub_ns, + &stub_pkg.package, + &stub_pkg.namespace_path, + &class_data.name, + &class_data.generic_params, + &method_generic_params, + sig.as_ref(), + type_args, + &bindings, + ), class_loc, func_loc, }); @@ -1280,20 +1341,13 @@ impl<'db> PackageResolutionContext<'db> { continue; } - let mut diags = Vec::new(); - let field_ty = field - .type_expr - .as_ref() - .map(|te| { - crate::generics::lower_type_expr_with_generics( - db, &te.expr, baml_items, stub_ns, &bindings, &mut diags, - ) - }) - .unwrap_or(Ty::Unknown { - attr: TyAttr::default(), - }); - drop(diags); - return Some(ResolvedBuiltinMember::Field(field_ty)); + return Some(ResolvedBuiltinMember::Field(lower_builtin_field_ty( + db, + baml_items, + stub_ns, + field.type_expr.as_ref(), + &bindings, + ))); } None @@ -1303,7 +1357,7 @@ impl<'db> PackageResolutionContext<'db> { /// Uses `generic_params` from the own-package item tree. fn apply_method_substitution( class_name: &NominalTypeRef, - resolved_method: &mut ResolvedMethod, + resolved_method: &mut ResolvedMethod<'_>, ) { if class_name.type_args().is_empty() { return; @@ -1377,7 +1431,7 @@ impl<'db> PackageResolutionContext<'db> { db: &'db dyn crate::Db, class_name: &QualifiedTypeName, method_name: &Name, - ) -> Option { + ) -> Option> { let def = self .own_items .lookup_type(class_name.namespace(), class_name.name())?; @@ -1440,6 +1494,7 @@ impl<'db> PackageResolutionContext<'db> { return Some(ResolvedMethod { function: ResolvedFunction { + func_loc: method_loc, name: method_data.name.clone(), params, return_type, diff --git a/baml_language/crates/baml_compiler2_tir/src/resolve.rs b/baml_language/crates/baml_compiler2_tir/src/resolve.rs index 82818613ab..1930e3cb95 100644 --- a/baml_language/crates/baml_compiler2_tir/src/resolve.rs +++ b/baml_language/crates/baml_compiler2_tir/src/resolve.rs @@ -126,16 +126,16 @@ pub fn resolve_name_at_in_scope<'db>( ) { return ResolvedName::Item(def); } - for (dep_name, _) in &res_ctx.dep_interfaces { + for dep in &res_ctx.deps { if let Some(def) = res_ctx .lookup_type_definition_in_package( db, - dep_name, + &dep.name, &pkg_info.namespace_path, name, ) .or_else(|| { - res_ctx.lookup_type_definition_in_package(db, dep_name, &[], name) + res_ctx.lookup_type_definition_in_package(db, &dep.name, &[], name) }) { return ResolvedName::Builtin(def); diff --git a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs index 1a578c1cc9..7b7b49a68d 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs @@ -70,8 +70,8 @@ pub fn function_throw_sets<'db>( let res_ctx = crate::package_interface::package_resolution_context(db, package_id); let mut aliases = collect_type_aliases(db, pkg_items); // Merge dependency type aliases for cross-package field type resolution - for (_dep_name, dep_iface) in &res_ctx.dep_interfaces { - for types_in_ns in dep_iface.types.values() { + for dep in &res_ctx.deps { + for types_in_ns in dep.interface.types.values() { for exported in types_in_ns.values() { if let crate::package_interface::ExportedType::TypeAlias { qtn, resolved } = exported @@ -81,8 +81,7 @@ pub fn function_throw_sets<'db>( } } } - let dep_packages = &res_ctx.dep_packages; - let dep_interfaces = &res_ctx.dep_interfaces; + let deps = &res_ctx.deps; let mut graph: crate::analysis::AnalysisGraph = crate::analysis::AnalysisGraph::new(); @@ -271,7 +270,7 @@ pub fn function_throw_sets<'db>( continue; } for to in targets { - if let Some(dep_throws) = lookup_dep_throw_set(db, dep_packages, dep_interfaces, to) { + if let Some(dep_throws) = lookup_dep_throw_set(db, deps, to) { // Cross-package: merge dependency's transitive throw facts into caller's direct facts direct_facts .entry(from.clone()) @@ -295,7 +294,7 @@ pub fn function_throw_sets<'db>( continue; } for to in targets { - if lookup_dep_throw_set(db, dep_packages, dep_interfaces, to).is_none() { + if lookup_dep_throw_set(db, deps, to).is_none() { graph.add_edge(from.clone(), to.clone()); } } @@ -652,6 +651,13 @@ fn collect_member_field_call_throws<'db>( .collect() } +/// Conservative pre-inference recovery pass used only for outward-throws +/// collection. +/// +/// This intentionally duplicates a narrow slice of expression typing because +/// `function_throw_sets` runs before full scope inference and must remain +/// cycle-safe. Supported shapes should be kept explicit and conservative; full +/// TIR inference remains the source of truth for general expression semantics. struct MemberFieldCallCollector<'a, 'db> { db: &'db dyn crate::Db, res_ctx: &'a crate::package_interface::PackageResolutionContext<'db>, @@ -826,7 +832,7 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { } => { let inferred_ty = initializer.and_then(|expr_id| { self.visit_expr(expr_id, env); - self.resolve_expr_ty(expr_id, env) + self.recover_expr_ty(expr_id, env) }); if let Some((binding_name, binding_ty)) = self.resolve_binding_ty(*pattern, *type_annotation, inferred_ty) @@ -896,7 +902,7 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { _ => return, }; - let Some(base_ty) = self.resolve_expr_ty(base_id, env) else { + let Some(base_ty) = self.recover_expr_ty(base_id, env) else { return; }; let resolved_base = crate::throws_semantics::resolve_alias_chain(&base_ty, self.aliases); @@ -923,27 +929,19 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { } } - fn resolve_expr_ty( + fn recover_expr_ty( &self, expr_id: baml_compiler2_ast::ExprId, env: &HashMap, ) -> Option { match &self.body.exprs[expr_id] { - Expr::Path(segments) if segments.len() == 1 => { - let name = &segments[0]; - if name.as_str() == "self" { - self.self_ty() - } else { - env.get(name).cloned() - } - } - Expr::Path(segments) => self.named_callable_ty(segments), + Expr::Path(segments) => self.recover_path_ty(segments, env), Expr::FieldAccess { base, field } | Expr::OptionalFieldAccess { base, field } => { - let base_ty = self.resolve_expr_ty(*base, env)?; + let base_ty = self.recover_expr_ty(*base, env)?; self.resolve_member_ty(&base_ty, field) } Expr::Call { callee, .. } | Expr::OptionalCall { callee, .. } => { - let callee_ty = self.resolve_expr_ty(*callee, env)?; + let callee_ty = self.recover_expr_ty(*callee, env)?; match callee_ty { Ty::Function { ret, .. } => Some(*ret), Ty::Optional(inner, _) => match *inner { @@ -957,53 +955,76 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { type_name: Some(name), type_args, .. - } => { - let def = self.res_ctx.own_items.lookup_type(self.ns_context, name)?; - match def { - Definition::Class(_) => { - let qtn = qualify_def(self.db, def, name); - let mut diags = Vec::new(); - let lowered_type_args = type_args - .iter() - .map(|te| { - lower_type_expr_in_ns( - self.db, - te, - self.res_ctx.own_items, - self.ns_context, - self.generic_params, - &mut diags, - ) - }) - .collect(); - Some(Ty::Class( - crate::ty::NominalTypeRef::new_with_type_args(qtn, lowered_type_args), - TyAttr::default(), - )) - } - Definition::Enum(_) => { - let qtn = qualify_def(self.db, def, name); - Some(Ty::Enum( - crate::ty::NominalTypeRef::new_with_type_args(qtn, Vec::new()), - TyAttr::default(), - )) - } - _ => None, - } + } => self.recover_typed_object_ty(name, type_args), + Expr::Array { elements } => self.recover_array_ty(elements, env), + Expr::OptionalChain { expr } => self.recover_expr_ty(*expr, env), + _ => None, + } + } + + fn recover_path_ty(&self, segments: &[Name], env: &HashMap) -> Option { + if segments.len() == 1 { + let name = &segments[0]; + if name.as_str() == "self" { + self.self_ty() + } else { + env.get(name) + .cloned() + .or_else(|| self.named_callable_ty(segments)) } - Expr::Array { elements } => { - let element_tys: Vec = elements + } else { + self.named_callable_ty(segments) + } + } + + fn recover_typed_object_ty(&self, name: &Name, type_args: &[TypeExpr]) -> Option { + let def = self.res_ctx.own_items.lookup_type(self.ns_context, name)?; + match def { + Definition::Class(_) => { + let qtn = qualify_def(self.db, def, name); + let mut diags = Vec::new(); + let lowered_type_args = type_args .iter() - .map(|expr_id| self.resolve_expr_ty(*expr_id, env)) - .collect::>>()?; - let element_ty = collapse_types(element_tys)?; - Some(Ty::List(Box::new(element_ty), TyAttr::default())) + .map(|te| { + lower_type_expr_in_ns( + self.db, + te, + self.res_ctx.own_items, + self.ns_context, + self.generic_params, + &mut diags, + ) + }) + .collect(); + Some(Ty::Class( + crate::ty::NominalTypeRef::new_with_type_args(qtn, lowered_type_args), + TyAttr::default(), + )) + } + Definition::Enum(_) => { + let qtn = qualify_def(self.db, def, name); + Some(Ty::Enum( + crate::ty::NominalTypeRef::new_with_type_args(qtn, Vec::new()), + TyAttr::default(), + )) } - Expr::OptionalChain { expr } => self.resolve_expr_ty(*expr, env), _ => None, } } + fn recover_array_ty( + &self, + elements: &[baml_compiler2_ast::ExprId], + env: &HashMap, + ) -> Option { + let element_tys: Vec = elements + .iter() + .map(|expr_id| self.recover_expr_ty(*expr_id, env)) + .collect::>>()?; + let element_ty = collapse_types(element_tys)?; + Some(Ty::List(Box::new(element_ty), TyAttr::default())) + } + fn resolve_member_ty(&self, base_ty: &Ty, field: &Name) -> Option { let resolved_base = crate::throws_semantics::resolve_alias_chain(base_ty, self.aliases); let Ty::Class(class_name, _) = &resolved_base else { @@ -1018,37 +1039,14 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { let method = self .res_ctx .lookup_class_method(self.db, class_name, field)?; - Some(Ty::Function { - params: method - .function - .params - .into_iter() - .map(|(name, ty)| (Some(name), ty)) - .collect(), - ret: Box::new(method.function.return_type), - throws: Box::new(method.function.outward_throws.unwrap_or(Ty::Never { - attr: TyAttr::default(), - })), - attr: TyAttr::default(), - }) + Some(method.function.as_ty()) } fn named_callable_ty(&self, path: &[Name]) -> Option { let (_source, function) = self .res_ctx .lookup_function(self.db, path, self.ns_context)?; - Some(Ty::Function { - params: function - .params - .into_iter() - .map(|(name, ty)| (Some(name), ty)) - .collect(), - ret: Box::new(function.return_type), - throws: Box::new(function.outward_throws.unwrap_or(Ty::Never { - attr: TyAttr::default(), - })), - attr: TyAttr::default(), - }) + Some(function.as_ty()) } fn resolve_binding_ty( @@ -1075,7 +1073,7 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { ) -> Option<(Name, Ty)> { match &self.body.exprs[target] { Expr::Path(segments) if segments.len() == 1 => { - Some((segments[0].clone(), self.resolve_expr_ty(value, env)?)) + Some((segments[0].clone(), self.recover_expr_ty(value, env)?)) } _ => None, } @@ -1086,7 +1084,7 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { collection: baml_compiler2_ast::ExprId, env: &HashMap, ) -> Option { - match self.resolve_expr_ty(collection, env)? { + match self.recover_expr_ty(collection, env)? { Ty::List(inner, _) => Some(*inner), _ => None, } @@ -1151,18 +1149,14 @@ fn collapse_types(types: Vec) -> Option { /// Look up a function's transitive throw set from dependency interfaces. fn lookup_dep_throw_set<'db>( db: &'db dyn crate::Db, - dep_packages: &[(Name, baml_compiler2_hir::package::PackageId<'db>)], - dep_interfaces: &[(Name, &crate::package_interface::PackageInterface)], + deps: &[crate::package_interface::ResolvedDependency<'db>], target_name: &Name, ) -> Option> { - for (dep_name, dep_iface) in dep_interfaces { - if !dep_iface.callable_keys.contains(target_name) { + for dep in deps { + if !dep.interface.callable_keys.contains(target_name) { continue; } - let dep_pkg_id = dep_packages - .iter() - .find_map(|(pkg_name, dep_pkg_id)| (pkg_name == dep_name).then_some(*dep_pkg_id))?; - if let Some(throws) = function_throw_sets(db, dep_pkg_id).transitive_for(target_name) { + if let Some(throws) = function_throw_sets(db, dep.package_id).transitive_for(target_name) { return Some(throws.clone()); } } From 732798692bc395fdc448be6133fc00407f935707 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Fri, 10 Apr 2026 03:05:56 -0500 Subject: [PATCH 25/26] Expand throws regression coverage --- .../baml_lsp2_actions/src/type_info_tests.rs | 91 ++++++ .../member_receiver_shapes.baml | 47 +++ .../optional_callback_direct.baml | 12 + ...ows__01_lexer__member_receiver_shapes.snap | 208 ++++++++++++ ...s__01_lexer__optional_callback_direct.snap | 73 +++++ ...ws__02_parser__member_receiver_shapes.snap | 308 ++++++++++++++++++ ...__02_parser__optional_callback_direct.snap | 96 ++++++ ...l_tests__function_type_throws__03_hir.snap | 38 +++ ...tests__function_type_throws__04_5_mir.snap | 198 +++++++++++ ...l_tests__function_type_throws__04_tir.snap | 80 ++++- ...sts__function_type_throws__06_codegen.snap | 89 +++++ ..._10_formatter__member_receiver_shapes.snap | 6 + ...0_formatter__optional_callback_direct.snap | 16 + 13 files changed, 1260 insertions(+), 2 deletions(-) create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/member_receiver_shapes.baml create mode 100644 baml_language/crates/baml_tests/projects/function_type_throws/optional_callback_direct.baml create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__member_receiver_shapes.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__optional_callback_direct.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__member_receiver_shapes.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__optional_callback_direct.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__member_receiver_shapes.snap create mode 100644 baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__optional_callback_direct.snap diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs index 6988d40943..d6107bc432 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs @@ -418,4 +418,95 @@ class Runner { "Method hover should preserve the declared throws clause, got: {md}" ); } + + #[test] + fn hover_function_with_inferred_throws_from_returned_receiver_call() { + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +function make_handler() -> Handler { + Handler { run: () -> null { throw "boom" } } +} + +function <[CURSOR]use_returned_receiver() -> null { + make_handler().run() +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws.as_deref(), Some("string")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws string"), + "Hover should show inferred throws from returned receiver call, got: {md}" + ); + } + + #[test] + fn hover_method_with_inferred_throws_from_self_field_call() { + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +class Runner { + primary: Handler + + function <[CURSOR]use_self(self) -> null { + self.primary.run() + } +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws.as_deref(), Some("string")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws string"), + "Method hover should show inferred throws from self field call, got: {md}" + ); + } + + #[test] + fn hover_function_with_qualified_declared_throws_preserves_source_clause() { + let test = CursorTest::new( + r#" +namespace root.errors { + class Io { + message: string + } +} + +function <[CURSOR]qualified() -> null throws root.errors.Io | string { + null +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws.as_deref(), Some("root.errors.Io | string")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws root.errors.Io | string"), + "Hover should preserve the qualified declared throws clause, got: {md}" + ); + } } diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/member_receiver_shapes.baml b/baml_language/crates/baml_tests/projects/function_type_throws/member_receiver_shapes.baml new file mode 100644 index 0000000000..eb34243c9e --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/member_receiver_shapes.baml @@ -0,0 +1,47 @@ +// === Additional member receiver throws propagation shapes === + +class EdgeHandler { + run: () -> null throws E +} + +class EdgeHolder { + item: T +} + +type EdgeStringHandler = EdgeHandler + +class EdgeService { + primary: EdgeHandler + + function call_self_field(self) -> null { + self.primary.run() + } + + function call_self_alias(self) -> null { + let me = self + me.primary.run() + } +} + +function make_edge_handler() -> EdgeHandler { + EdgeHandler { run: () -> null { throw "edge boom" } } +} + +function call_returned_receiver() -> null { + make_edge_handler().run() +} + +function call_typed_local_receiver() -> null { + let h: EdgeHandler = make_edge_handler() + h.run() +} + +function call_nested_alias_receiver(h: EdgeHandler) -> null { + let a = h + let b = a + b.run() +} + +function call_field_chained_receiver(holder: EdgeHolder) -> null { + holder.item.run() +} diff --git a/baml_language/crates/baml_tests/projects/function_type_throws/optional_callback_direct.baml b/baml_language/crates/baml_tests/projects/function_type_throws/optional_callback_direct.baml new file mode 100644 index 0000000000..b16abe8de8 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/function_type_throws/optional_callback_direct.baml @@ -0,0 +1,12 @@ +// === Optional direct callback invocation with omitted throws === + +function optional_apply(cb: (() -> int)?) -> int? { + cb?.() +} + +function optional_apply_with_body(cb: (() -> int)?) -> int? { + if (cb == null) { + throw "missing callback" + } + cb() +} diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__member_receiver_shapes.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__member_receiver_shapes.snap new file mode 100644 index 0000000000..6a8496a623 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__member_receiver_shapes.snap @@ -0,0 +1,208 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +assertion_line: 11289 +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Additional" +Word "member" +Word "receiver" +Throws "throws" +Word "propagation" +Word "shapes" +EqualsEquals "==" +Equals "=" +Class "class" +Word "EdgeHandler" +Less "<" +Word "E" +Greater ">" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +Throws "throws" +Word "E" +RBrace "}" +Class "class" +Word "EdgeHolder" +Less "<" +Word "T" +Greater ">" +LBrace "{" +Word "item" +Colon ":" +Word "T" +RBrace "}" +Word "type" +Word "EdgeStringHandler" +Equals "=" +Word "EdgeHandler" +Less "<" +Word "string" +Greater ">" +Class "class" +Word "EdgeService" +LBrace "{" +Word "primary" +Colon ":" +Word "EdgeHandler" +Less "<" +Word "string" +Greater ">" +Function "function" +Word "call_self_field" +LParen "(" +Word "self" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "self" +Dot "." +Word "primary" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "call_self_alias" +LParen "(" +Word "self" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "me" +Equals "=" +Word "self" +Word "me" +Dot "." +Word "primary" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +RBrace "}" +Function "function" +Word "make_edge_handler" +LParen "(" +RParen ")" +Arrow "->" +Word "EdgeHandler" +Less "<" +Word "string" +Greater ">" +LBrace "{" +Word "EdgeHandler" +LBrace "{" +Word "run" +Colon ":" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Throw "throw" +Quote "\"" +Word "edge" +Word "boom" +Quote "\"" +RBrace "}" +RBrace "}" +RBrace "}" +Function "function" +Word "call_returned_receiver" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "make_edge_handler" +LParen "(" +RParen ")" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "call_typed_local_receiver" +LParen "(" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "h" +Colon ":" +Word "EdgeHandler" +Less "<" +Word "string" +Greater ">" +Equals "=" +Word "make_edge_handler" +LParen "(" +RParen ")" +Word "h" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "call_nested_alias_receiver" +LParen "(" +Word "h" +Colon ":" +Word "EdgeHandler" +Less "<" +Word "string" +Greater ">" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Let "let" +Word "a" +Equals "=" +Word "h" +Let "let" +Word "b" +Equals "=" +Word "a" +Word "b" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "call_field_chained_receiver" +LParen "(" +Word "holder" +Colon ":" +Word "EdgeHolder" +Less "<" +Word "EdgeStringHandler" +Greater ">" +RParen ")" +Arrow "->" +Word "null" +LBrace "{" +Word "holder" +Dot "." +Word "item" +Dot "." +Word "run" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__optional_callback_direct.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__optional_callback_direct.snap new file mode 100644 index 0000000000..574348f10f --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__01_lexer__optional_callback_direct.snap @@ -0,0 +1,73 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +assertion_line: 11381 +--- +Slash "/" +Slash "/" +EqualsEquals "==" +Equals "=" +Word "Optional" +Word "direct" +Word "callback" +Word "invocation" +Word "with" +Word "omitted" +Throws "throws" +EqualsEquals "==" +Equals "=" +Function "function" +Word "optional_apply" +LParen "(" +Word "cb" +Colon ":" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Question "?" +RParen ")" +Arrow "->" +Word "int" +Question "?" +LBrace "{" +Word "cb" +QuestionDot "?." +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "optional_apply_with_body" +LParen "(" +Word "cb" +Colon ":" +LParen "(" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +RParen ")" +Question "?" +RParen ")" +Arrow "->" +Word "int" +Question "?" +LBrace "{" +If "if" +LParen "(" +Word "cb" +EqualsEquals "==" +Word "null" +RParen ")" +LBrace "{" +Throw "throw" +Quote "\"" +Word "missing" +Word "callback" +Quote "\"" +RBrace "}" +Word "cb" +LParen "(" +RParen ")" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__member_receiver_shapes.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__member_receiver_shapes.snap new file mode 100644 index 0000000000..a8d008182a --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__member_receiver_shapes.snap @@ -0,0 +1,308 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +assertion_line: 12246 +--- +=== SYNTAX TREE === +SOURCE_FILE + CLASS_DEF + KW_CLASS "class" + WORD "EdgeHandler" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "E" + WORD "E" + GREATER ">" + L_BRACE "{" + FIELD + WORD "run" + COLON ":" + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + THROWS_CLAUSE + KW_THROWS "throws" + TYPE_EXPR "E" + WORD "E" + R_BRACE "}" + CLASS_DEF + KW_CLASS "class" + WORD "EdgeHolder" + GENERIC_PARAM_LIST + LESS "<" + GENERIC_PARAM "T" + WORD "T" + GREATER ">" + L_BRACE "{" + FIELD + WORD "item" + COLON ":" + TYPE_EXPR "T" + WORD "T" + R_BRACE "}" + TYPE_ALIAS_DEF + WORD "type" + WORD "EdgeStringHandler" + EQUALS "=" + TYPE_EXPR + WORD "EdgeHandler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + CLASS_DEF + KW_CLASS "class" + WORD "EdgeService" + L_BRACE "{" + FIELD + WORD "primary" + COLON ":" + TYPE_EXPR + WORD "EdgeHandler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "call_self_field" + PARAMETER_LIST + L_PAREN "(" + PARAMETER "self" + WORD "self" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + PATH_EXPR "self.primary.run" + WORD "self" + DOT "." + WORD "primary" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "call_self_alias" + PARAMETER_LIST + L_PAREN "(" + PARAMETER "self" + WORD "self" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let me = self" + KW_LET "let" + WORD "me" + EQUALS "=" + WORD "self" + CALL_EXPR + PATH_EXPR "me.primary.run" + WORD "me" + DOT "." + WORD "primary" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "make_edge_handler" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR + WORD "EdgeHandler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OBJECT_LITERAL + WORD "EdgeHandler" + L_BRACE "{" + OBJECT_FIELD + WORD "run" + COLON ":" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "edge boom" + QUOTE """ + WORD "edge" + WORD "boom" + QUOTE """ + R_BRACE "}" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "call_returned_receiver" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + FIELD_ACCESS_EXPR + CALL_EXPR + WORD "make_edge_handler" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "call_typed_local_receiver" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "h" + COLON ":" + TYPE_EXPR + WORD "EdgeHandler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + EQUALS "=" + CALL_EXPR + WORD "make_edge_handler" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + CALL_EXPR + PATH_EXPR "h.run" + WORD "h" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "call_nested_alias_receiver" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "h" + COLON ":" + TYPE_EXPR + WORD "EdgeHandler" + TYPE_ARGS + LESS "<" + TYPE_EXPR "string" + WORD "string" + GREATER ">" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let a = h" + KW_LET "let" + WORD "a" + EQUALS "=" + WORD "h" + LET_STMT "let b = a" + KW_LET "let" + WORD "b" + EQUALS "=" + WORD "a" + CALL_EXPR + PATH_EXPR "b.run" + WORD "b" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "call_field_chained_receiver" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "holder" + COLON ":" + TYPE_EXPR + WORD "EdgeHolder" + TYPE_ARGS + LESS "<" + TYPE_EXPR "EdgeStringHandler" + WORD "EdgeStringHandler" + GREATER ">" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "null" + WORD "null" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + CALL_EXPR + PATH_EXPR "holder.item.run" + WORD "holder" + DOT "." + WORD "item" + DOT "." + WORD "run" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__optional_callback_direct.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__optional_callback_direct.snap new file mode 100644 index 0000000000..90ce234377 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__02_parser__optional_callback_direct.snap @@ -0,0 +1,96 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +assertion_line: 12402 +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "optional_apply" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "cb" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + QUESTION "?" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int?" + WORD "int" + QUESTION "?" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + OPTIONAL_CALL_EXPR + WORD "cb" + QUESTION_DOT "?." + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "optional_apply_with_body" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "cb" + COLON ":" + TYPE_EXPR + L_PAREN "(" + FUNCTION_TYPE_PARAM + TYPE_EXPR + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + QUESTION "?" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int?" + WORD "int" + QUESTION "?" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR + L_PAREN "(" + BINARY_EXPR "cb == null" + WORD "cb" + EQUALS_EQUALS "==" + WORD "null" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + THROW_STMT + THROW_EXPR + KW_THROW "throw" + STRING_LITERAL "missing callback" + QUOTE """ + WORD "missing" + WORD "callback" + QUOTE """ + R_BRACE "}" + CALL_EXPR + WORD "cb" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap index 2ef8a7ec30..0a5fe29366 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__03_hir.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 12712 --- === HIR2 === function user.test_array_map_pure() -> int[] [expr] { @@ -444,6 +445,37 @@ function user.test_typed_lambda_annotation_underdeclares_body() -> int [expr] { function user.test_typed_lambda_throws_mismatch() -> int [expr] { { let f: () -> int = () -> int { { throw 42 } } } f() } +class user.EdgeHandler { + run: () -> null +} +class user.EdgeHolder { + item: user.T +} +class user.EdgeService { + primary: user.EdgeHandler +} +type user.EdgeStringHandler = user.EdgeHandler +function user.call_field_chained_receiver(holder: user.EdgeHolder) -> null [expr] { + { } holder.item.run() +} +function user.call_nested_alias_receiver(h: user.EdgeHandler) -> null [expr] { + { let a = h; let b = a } b.run() +} +function user.call_returned_receiver() -> null [expr] { + { } make_edge_handler().run() +} +function user.call_self_alias(self: ?) -> null [expr] { + { let me = self } me.primary.run() +} +function user.call_self_field(self: ?) -> null [expr] { + { } self.primary.run() +} +function user.call_typed_local_receiver() -> null [expr] { + { let h: user.EdgeHandler = make_edge_handler() } h.run() +} +function user.make_edge_handler() -> user.EdgeHandler [expr] { + { } user.EdgeHandler { run: () -> null { { throw "edge boom" } } } +} function user.apply_with_arg(x: int, f: (int) -> int) -> int [expr] { { } f(x) } @@ -477,6 +509,12 @@ function user.optional_call_rethrows(cb: () -> int?) -> int? [expr] { function user.test_optional_call_with_throwing_callback() -> int? [expr] { { } optional_call_rethrows(() -> int { { throw "boom" } }) } +function user.optional_apply(cb: () -> int?) -> int? [expr] { + { } cb?.() +} +function user.optional_apply_with_body(cb: () -> int?) -> int? [expr] { + { if (cb Eq null) { throw "missing callback" } } cb() +} function user.make_pure() -> () -> int [expr] { { return () -> int { { } 42 } } } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index a73c4bdc9e..0ebb68fbfa 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 13184 --- === MIR2 === fn user.test_array_map_pure() -> int[] { @@ -3332,6 +3333,154 @@ fn .() -> null { } } +fn user.call_field_chained_receiver(holder: EdgeHolder) -> null { + // Locals: + let _0: null // _0 // return + let _1: EdgeHolder // holder // param + let _2: () -> null throws string + let _3: EdgeHandler + + bb0: { + _3 = copy _1.0; + _2 = copy _3.0; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.call_nested_alias_receiver(h: EdgeHandler) -> null { + // Locals: + let _0: null // _0 // return + let _1: EdgeHandler // h // param + let _2: EdgeHandler // a + let _3: EdgeHandler // b + let _4: () -> null throws string + let _5: EdgeHandler + + bb0: { + _2 = copy _1; + _3 = copy _2; + _5 = copy _3; + _4 = copy _5.0; + _0 = call copy _4() -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.call_returned_receiver() -> null { + // Locals: + let _0: null // _0 // return + let _1: () -> null throws string + let _2: EdgeHandler + + bb0: { + _2 = call const fn user.make_edge_handler() -> [bb1]; + } + + bb1: { + _1 = copy _2.0; + _0 = call copy _1() -> [bb2]; + } + + bb2: { + return; + } +} + +fn user.EdgeService.call_self_alias(self: EdgeService) -> null { + // Locals: + let _0: null // _0 // return + let _1: EdgeService // self // param + let _2: EdgeService // me + let _3: () -> null throws string + let _4: EdgeHandler + let _5: EdgeService + + bb0: { + _2 = copy _1; + _5 = copy _2; + _4 = copy _5.0; + _3 = copy _4.0; + _0 = call copy _3() -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.EdgeService.call_self_field(self: EdgeService) -> null { + // Locals: + let _0: null // _0 // return + let _1: EdgeService // self // param + let _2: () -> null throws string + let _3: EdgeHandler + + bb0: { + _3 = copy _1.0; + _2 = copy _3.0; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + return; + } +} + +fn user.call_typed_local_receiver() -> null { + // Locals: + let _0: null // _0 // return + let _1: EdgeHandler // h + let _2: () -> null throws string + let _3: EdgeHandler + + bb0: { + _1 = call const fn user.make_edge_handler() -> [bb1]; + } + + bb1: { + _3 = copy _1; + _2 = copy _3.0; + _0 = call copy _2() -> [bb2]; + } + + bb2: { + return; + } +} + +fn user.make_edge_handler() -> EdgeHandler { + // Locals: + let _0: EdgeHandler // _0 // return + let _1: () -> void throws string + + bb0: { + _1 = make_closure lambda[0](); + _0 = EdgeHandler { copy _1 }; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + throw const "edge boom"; + } +} + fn user.apply_with_arg(x: int, f: (int) -> int) -> int { // Locals: let _0: int // _0 // return @@ -3685,6 +3834,55 @@ fn .() -> null { } } +fn user.optional_apply(cb: (() -> int)?) -> int? { + // Locals: + let _0: int? // _0 // return + let _1: (() -> int)? // cb // param + let _2: bool + + bb0: { + _2 = copy _1 == const null; + branch copy _2 -> [bb2, bb1]; + } + + bb1: { + _0 = call copy _1() -> [bb3]; + } + + bb2: { + _0 = const null; + goto -> bb3; + } + + bb3: { + return; + } +} + +fn user.optional_apply_with_body(cb: (() -> int)?) -> int? { + // Locals: + let _0: int? // _0 // return + let _1: (() -> int)? // cb // param + let _2: bool + + bb0: { + _2 = copy _1 == const null; + branch copy _2 -> [bb3, bb1]; + } + + bb1: { + _0 = call copy _1() -> [bb2]; + } + + bb2: { + return; + } + + bb3: { + throw const "missing callback"; + } +} + fn user.make_pure() -> () -> int { // Locals: let _0: () -> int // _0 // return diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 4c0db4fdaa..6e5358ae85 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 12944 --- === TIR2 === function user.test_array_map_pure() -> int[] throws never { @@ -747,7 +748,7 @@ function user.test_string_handler_alias() -> null throws string { use_string_handler_alias(make_string_handler()) : null } } -function user.test_local_string_handler() -> null throws never { +function user.test_local_string_handler() -> null throws string { { : null use_local_string_handler() : null } @@ -843,7 +844,7 @@ function user.test_run_tasks_alias() -> null throws int | string { run_tasks_alias(make_mixed_tasks()) : null } } -function user.test_run_local_tasks() -> null throws never { +function user.test_run_local_tasks() -> null throws int | string { { : null run_local_tasks() : null } @@ -1300,6 +1301,67 @@ lambda user.test_typed_lambda_annotation_underdeclares_body { !! 654..661: throws contract violation: `string` is missing int ?? 654..661: extraneous throws declaration: string } +class user.EdgeHandler { + run: () -> null throws E +} +class user.EdgeHolder { + item: T +} +type user.EdgeStringHandler = user.EdgeHandler +class user.EdgeService { + primary: user.EdgeHandler +} +function user.EdgeService.call_self_field(self: user.EdgeService) -> null throws string { + { : null + self.primary.run() : null + } +} +function user.EdgeService.call_self_alias(self: user.EdgeService) -> null throws string { + { : null + let me = self : user.EdgeService + me.primary.run() : null + } +} +function user.make_edge_handler() -> user.EdgeHandler throws never { + { : user.EdgeHandler + EdgeHandler { run: () -> null { ... } } : user.EdgeHandler + } +} +lambda user.make_edge_handler { +} +function user.call_returned_receiver() -> null throws string { + { : null + make_edge_handler().run() : null + } +} +function user.call_typed_local_receiver() -> null throws string { + { : null + let h = make_edge_handler() : user.EdgeHandler + h.run() : null + } +} +function user.call_nested_alias_receiver(h: user.EdgeHandler) -> null throws string { + { : null + let a = h : user.EdgeHandler + let b = a : user.EdgeHandler + b.run() : null + } +} +function user.call_field_chained_receiver(holder: user.EdgeHolder) -> null throws string { + { : null + holder.item.run() : null + } +} +class user.EdgeHandler$stream { + run: unknown +} +class user.EdgeHolder$stream { + item: null | unknown +} +type user.EdgeStringHandler$stream = user.EdgeHandler$stream +class user.EdgeService$stream { + primary: null | user.EdgeHandler$stream +} function user.apply_with_arg(x: int, f: (int) -> int throws __throws_f) -> int throws __throws_f { { : int f(x) : int @@ -1431,6 +1493,20 @@ function user.test_optional_call_with_throwing_callback() -> int? throws string } lambda user.test_optional_call_with_throwing_callback { } +function user.optional_apply(cb: (() -> int)?) -> int? throws never { + { : int? + cb?.() : int? + } +} +function user.optional_apply_with_body(cb: (() -> int)?) -> int? throws string { + { : int + if (cb == null : bool) : void + { : never + throw "missing callback" : "missing callback" + } + cb() : int + } +} function user.make_pure() -> () -> int throws never { { : never return : () -> 42 diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap index 9769e235e1..c3d842267b 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__06_codegen.snap @@ -1,6 +1,23 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 13683 --- +function user.EdgeService.call_self_alias(self: null) -> null { + load_var self + load_field .primary + load_field .run + call_indirect + return +} + +function user.EdgeService.call_self_field(self: null) -> null { + load_var self + load_field .primary + load_field .run + call_indirect + return +} + function user.MethodRunner.apply(self: null, f: (void) -> int) -> int { load_var self load_field .value @@ -163,6 +180,35 @@ function user.bad_box_mismatch() -> null { return } +function user.call_field_chained_receiver(holder: EdgeHolder) -> null { + load_var holder + load_field .item + load_field .run + call_indirect + return +} + +function user.call_nested_alias_receiver(h: EdgeHandler) -> null { + load_var h + load_field .run + call_indirect + return +} + +function user.call_returned_receiver() -> null { + call user.make_edge_handler + load_field .run + call_indirect + return +} + +function user.call_typed_local_receiver() -> null { + call user.make_edge_handler + load_field .run + call_indirect + return +} + function user.caller() -> string { load_const 1 call user.declared_may_fail @@ -341,6 +387,13 @@ function user.make_class_handler() -> ClassThrowingHandler { return } +function user.make_edge_handler() -> EdgeHandler { + alloc_instance EdgeHandler + make_closure ., 0 + init_field .run + return +} + function user.make_enum_handler() -> EnumThrowingHandler { alloc_instance EnumThrowingHandler make_closure ., 0 @@ -455,6 +508,42 @@ function user.maybe_use_string_handler(h: Handler?) -> null { return } +function user.optional_apply(cb: (() -> int)?) -> int? { + load_var cb + load_const null + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var cb + call_indirect + jump L2 + + L1: + load_const null + + L2: + return +} + +function user.optional_apply_with_body(cb: (() -> int)?) -> int? { + load_var cb + load_const null + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var cb + call_indirect + return + + L1: + load_const "missing callback" + throw +} + function user.optional_call_caught(cb: (() -> int throws string)?) -> int? { load_var cb load_const null diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__member_receiver_shapes.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__member_receiver_shapes.snap new file mode 100644 index 0000000000..4be83e7b64 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__member_receiver_shapes.snap @@ -0,0 +1,6 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +assertion_line: 14811 +--- +=== STRONG AST ERROR === +An element at member_receiver_shapes.baml:3:18 was a token when it should have been a node. diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__optional_callback_direct.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__optional_callback_direct.snap new file mode 100644 index 0000000000..9ed10fa304 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__10_formatter__optional_callback_direct.snap @@ -0,0 +1,16 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +assertion_line: 15039 +--- +// === Optional direct callback invocation with omitted throws === + +function optional_apply(cb: (() -> int)?) -> int? { + cb?.() +} + +function optional_apply_with_body(cb: (() -> int)?) -> int? { + if (cb == null) { + throw "missing callback" + } + cb() +} From 2d58aee3916b256a0ab5069fb3ae0e588cad8299 Mon Sep 17 00:00:00 2001 From: rossirpaulo Date: Thu, 16 Apr 2026 01:42:52 -0500 Subject: [PATCH 26/26] Refine typed throws inference for higher-order callbacks Preserve generic payloads for inline collections and optional calls so stored callbacks keep their inferred throws, and refresh hover and snapshot coverage to reflect the new behavior. Made-with: Cursor --- .../baml_std/baml/containers.baml | 20 ++- .../crates/baml_compiler2_tir/src/builder.rs | 76 ++++++++- .../crates/baml_compiler2_tir/src/generics.rs | 146 ++++++++++++++++ .../baml_compiler2_tir/src/throw_inference.rs | 158 ++++++++++++++++-- .../crates/baml_lsp2_actions/src/type_info.rs | 42 ++++- .../baml_lsp2_actions/src/type_info_tests.rs | 116 +++++++++++++ .../baml_tests____baml_std____03_hir.snap | 4 +- .../baml_tests____baml_std____04_5_mir.snap | 92 +++++++++- .../baml_tests____baml_std____04_tir.snap | 24 ++- .../baml_tests____baml_std____06_codegen.snap | 70 +++++++- ...tests__function_type_throws__04_5_mir.snap | 5 +- ...l_tests__function_type_throws__04_tir.snap | 12 +- ..._function_type_throws__05_diagnostics.snap | 10 -- .../baml_tests/src/compiler2_tir/phase6.rs | 28 ++++ 14 files changed, 736 insertions(+), 67 deletions(-) diff --git a/baml_language/crates/baml_builtins2/baml_std/baml/containers.baml b/baml_language/crates/baml_builtins2/baml_std/baml/containers.baml index efd0caa39b..a997f10bc6 100644 --- a/baml_language/crates/baml_builtins2/baml_std/baml/containers.baml +++ b/baml_language/crates/baml_builtins2/baml_std/baml/containers.baml @@ -72,11 +72,19 @@ class Map { $rust_function } - function map_keys(self, f: (K) -> U throws never) -> U[] { - self.keys().map(f) - } - - function map_values(self, f: (V) -> U throws never) -> U[] { - self.values().map(f) + function map_keys(self, f: (K) -> U) -> U[] { + let result: U[] = [] + for (let key in self.keys()) { + result.push(f(key)) + } + result + } + + function map_values(self, f: (V) -> U) -> U[] { + let result: U[] = [] + for (let value in self.values()) { + result.push(f(value)) + } + result } } diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 0ed568e303..38b46fce6a 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -19,6 +19,7 @@ use baml_base::{Name, SourceFile}; use baml_compiler2_ast::{Expr, ExprBody, ExprId, PatId, Stmt, StmtId, TypeExpr}; use baml_compiler2_hir::{ contributions::Definition, + file_item_tree, package::{PackageId, PackageItems}, scope::ScopeId, }; @@ -386,7 +387,8 @@ impl<'db> TypeInferenceBuilder<'db> { Expr::Array { elements } => { let elem_types: Vec = elements.iter().map(|e| self.infer_expr(*e, body)).collect(); - let elem_ty = Self::join_all(&elem_types).widen_fresh(); + let elem_ty = + crate::generics::merge_collection_member_types(elem_types).widen_fresh(); Ty::List(Box::new(elem_ty), TyAttr::default()) } Expr::Map { entries } => { @@ -488,7 +490,10 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expr(*field_expr, body); } } - let ty = resolved_ty.expect("resolved class type"); + let resolved_ty = resolved_ty.expect("resolved class type"); + let ty = self + .infer_object_literal_type_args(&resolved_ty, fields.as_slice()) + .unwrap_or(resolved_ty); self.record_expr_type(expr_id, ty.clone()); ty } else { @@ -956,15 +961,34 @@ impl<'db> TypeInferenceBuilder<'db> { return result_ty; } - let is_method_call = matches!(&body.exprs[*callee], Expr::FieldAccess { .. }); + let is_method_call = matches!( + &body.exprs[*callee], + Expr::FieldAccess { .. } | Expr::OptionalFieldAccess { .. } + ); + let optional_chain_call = + self.in_optional_chain > 0 && Self::expr_contains_optional(*callee, body); let callee_ty = self.infer_expr(*callee, body); // Expand type alias chains so alias-over-function types are callable. // Bare alias cycles are already caught by find_invalid_alias_cycles (Tarjan SCC) // before we reach here, so the depth guard is cheap insurance. let callee_ty = throws_semantics::resolve_alias_chain(&callee_ty, &self.aliases); + let callable_ty = if optional_chain_call { + crate::narrowing::remove_null(&callee_ty) + } else { + callee_ty.clone() + }; + + if optional_chain_call && matches!(&callable_ty, Ty::Never { .. }) { + for arg in args { + self.infer_expr(*arg, body); + } + let ty = Ty::Primitive(PrimitiveType::Null, TyAttr::default()); + self.record_expr_type(expr_id, ty.clone()); + return ty; + } - match &callee_ty { + match &callable_ty { Ty::Function { params, ret, @@ -1074,6 +1098,12 @@ impl<'db> TypeInferenceBuilder<'db> { self.record_expr_type(*callee, substituted_callee_ty); } + let result = if optional_chain_call { + Self::make_optional(result) + } else { + result + }; + // Subtype check against expected type (skip if we did generic // inference — the inference already accounts for expected) if bindings.is_empty() @@ -1106,7 +1136,7 @@ impl<'db> TypeInferenceBuilder<'db> { _ => { self.context.report_simple( TirTypeError::NotCallable { - ty: callee_ty.clone(), + ty: callable_ty.clone(), }, expr_id, ); @@ -2319,6 +2349,42 @@ impl<'db> TypeInferenceBuilder<'db> { None } + fn infer_object_literal_type_args( + &self, + class_ty: &Ty, + fields: &[(Name, ExprId)], + ) -> Option { + let Ty::Class(class_name, _) = &class_ty else { + return None; + }; + if !class_name.type_args().is_empty() { + return None; + } + + let class_loc = self.resolve_class_loc(class_name.qtn())?; + let item_tree = file_item_tree(self.context.db(), class_loc.file(self.context.db())); + let class_data = &item_tree[class_loc.id(self.context.db())]; + if class_data.generic_params.is_empty() { + return None; + } + + let field_types = self.lookup_class_fields(class_name); + let field_pairs: Vec<(Ty, Ty)> = fields + .iter() + .filter_map(|(field_name, field_expr)| { + let declared_ty = field_types.get(field_name)?; + let actual_ty = self.expressions.get(field_expr)?.clone().widen_fresh(); + Some((declared_ty.clone(), actual_ty)) + }) + .collect(); + + crate::generics::infer_nominal_type_args_from_fields( + class_ty, + &class_data.generic_params, + &field_pairs, + ) + } + fn literal_case_name(lit: &baml_base::Literal) -> String { match lit { baml_base::Literal::Int(v) => v.to_string(), diff --git a/baml_language/crates/baml_compiler2_tir/src/generics.rs b/baml_language/crates/baml_compiler2_tir/src/generics.rs index 5b7886632e..d846389b9f 100644 --- a/baml_language/crates/baml_compiler2_tir/src/generics.rs +++ b/baml_language/crates/baml_compiler2_tir/src/generics.rs @@ -42,6 +42,152 @@ pub fn bind_type_vars(generic_params: &[Name], concrete_args: &[Ty]) -> FxHashMa bindings } +/// Infer nominal type arguments from already-typed object fields. +/// +/// Callers pass the bare nominal type plus `(declared_field_ty, actual_field_ty)` +/// pairs. If every generic parameter is bound from the provided fields, returns +/// the nominal type with inferred type arguments applied. +pub fn infer_nominal_type_args_from_fields( + nominal_ty: &Ty, + generic_params: &[Name], + field_pairs: &[(Ty, Ty)], +) -> Option { + if generic_params.is_empty() || field_pairs.is_empty() { + return None; + } + + let mut bindings = FxHashMap::default(); + for (declared_ty, actual_ty) in field_pairs { + infer_bindings(declared_ty, actual_ty, &mut bindings); + } + + if !generic_params + .iter() + .all(|param| bindings.contains_key(param)) + { + return None; + } + + Some( + nominal_ty.clone().with_nominal_type_args( + generic_params + .iter() + .filter_map(|param| bindings.get(param).cloned()) + .collect(), + ), + ) +} + +/// Merge collection element members while preserving shared nominal payload. +/// +/// This keeps `Task` and `Task` from erasing to a bare `Task` or a +/// non-callable union when they are joined as array elements. +pub fn merge_collection_member_types(types: Vec) -> Ty { + let mut members = Vec::new(); + for ty in types { + push_collection_member(&mut members, ty); + } + + match members.len() { + 0 => Ty::Never { + attr: TyAttr::default(), + }, + 1 => members.pop().expect("single merged collection member"), + _ => Ty::Union(members, TyAttr::default()), + } +} + +fn push_collection_member(members: &mut Vec, ty: Ty) { + match ty { + Ty::Never { .. } => {} + Ty::Union(inner, _) => { + for member in inner { + push_collection_member(members, member); + } + } + other => { + if let Some(index) = members + .iter() + .position(|existing| same_nominal_identity(existing, &other)) + { + members[index] = merge_same_nominal_type_args(&members[index], &other) + .expect("same_nominal_identity implies mergeable nominal types"); + } else if !members.contains(&other) { + members.push(other); + } + } + } +} + +fn same_nominal_identity(a: &Ty, b: &Ty) -> bool { + match (a, b) { + (Ty::Class(a_nominal, _), Ty::Class(b_nominal, _)) => { + a_nominal.qtn() == b_nominal.qtn() + && a_nominal.type_args().len() == b_nominal.type_args().len() + } + (Ty::Enum(a_nominal, _), Ty::Enum(b_nominal, _)) => { + a_nominal.qtn() == b_nominal.qtn() + && a_nominal.type_args().len() == b_nominal.type_args().len() + } + (Ty::TypeAlias(a_nominal, _), Ty::TypeAlias(b_nominal, _)) => { + a_nominal.qtn() == b_nominal.qtn() + && a_nominal.type_args().len() == b_nominal.type_args().len() + } + _ => false, + } +} + +fn merge_same_nominal_type_args(a: &Ty, b: &Ty) -> Option { + match (a, b) { + (Ty::Class(a_nominal, _), Ty::Class(b_nominal, _)) + if a_nominal.qtn() == b_nominal.qtn() + && a_nominal.type_args().len() == b_nominal.type_args().len() => + { + Some( + a.clone().with_nominal_type_args( + a_nominal + .type_args() + .iter() + .zip(b_nominal.type_args().iter()) + .map(|(a_arg, b_arg)| union_ty(a_arg, b_arg)) + .collect(), + ), + ) + } + (Ty::Enum(a_nominal, _), Ty::Enum(b_nominal, _)) + if a_nominal.qtn() == b_nominal.qtn() + && a_nominal.type_args().len() == b_nominal.type_args().len() => + { + Some( + a.clone().with_nominal_type_args( + a_nominal + .type_args() + .iter() + .zip(b_nominal.type_args().iter()) + .map(|(a_arg, b_arg)| union_ty(a_arg, b_arg)) + .collect(), + ), + ) + } + (Ty::TypeAlias(a_nominal, _), Ty::TypeAlias(b_nominal, _)) + if a_nominal.qtn() == b_nominal.qtn() + && a_nominal.type_args().len() == b_nominal.type_args().len() => + { + Some( + a.clone().with_nominal_type_args( + a_nominal + .type_args() + .iter() + .zip(b_nominal.type_args().iter()) + .map(|(a_arg, b_arg)| union_ty(a_arg, b_arg)) + .collect(), + ), + ) + } + _ => None, + } +} + // ── Type substitution ───────────────────────────────────────────────────────── /// Substitute type variables in a `Ty` using the provided bindings. diff --git a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs index 7b7b49a68d..c8fb20001b 100644 --- a/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/throw_inference.rs @@ -11,6 +11,7 @@ use baml_base::Name; use baml_compiler2_ast::{Expr, ExprBody, Literal, Pattern, TypeExpr}; use baml_compiler2_hir::{ contributions::Definition, + file_item_tree, package::{PackageId, PackageItems, package_items}, }; @@ -905,7 +906,13 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { let Some(base_ty) = self.recover_expr_ty(base_id, env) else { return; }; - let resolved_base = crate::throws_semantics::resolve_alias_chain(&base_ty, self.aliases); + let resolved_base = + match crate::throws_semantics::resolve_alias_chain(&base_ty, self.aliases) { + Ty::Optional(inner, _) => { + crate::throws_semantics::resolve_alias_chain(inner.as_ref(), self.aliases) + } + other => other, + }; let Ty::Class(class_name, _) = &resolved_base else { return; }; @@ -935,6 +942,16 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { env: &HashMap, ) -> Option { match &self.body.exprs[expr_id] { + Expr::Literal(lit) => Some(match lit { + Literal::String(_) => Ty::Primitive(PrimitiveType::String, TyAttr::default()), + Literal::Int(_) => Ty::Primitive(PrimitiveType::Int, TyAttr::default()), + Literal::Float(_) => Ty::Primitive(PrimitiveType::Float, TyAttr::default()), + Literal::Bool(_) => Ty::Primitive(PrimitiveType::Bool, TyAttr::default()), + }), + Expr::ByteStringLiteral(_) => { + Some(Ty::Primitive(PrimitiveType::Uint8Array, TyAttr::default())) + } + Expr::Null => Some(Ty::Primitive(PrimitiveType::Null, TyAttr::default())), Expr::Path(segments) => self.recover_path_ty(segments, env), Expr::FieldAccess { base, field } | Expr::OptionalFieldAccess { base, field } => { let base_ty = self.recover_expr_ty(*base, env)?; @@ -954,9 +971,14 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { Expr::Object { type_name: Some(name), type_args, + fields, .. - } => self.recover_typed_object_ty(name, type_args), + } => self.recover_typed_object_ty(name, type_args, fields, env), Expr::Array { elements } => self.recover_array_ty(elements, env), + Expr::Block { tail_expr, .. } => { + tail_expr.and_then(|tail| self.recover_expr_ty(tail, env)) + } + Expr::Lambda(func_def) => Some(self.recover_lambda_ty(func_def, env)), Expr::OptionalChain { expr } => self.recover_expr_ty(*expr, env), _ => None, } @@ -977,7 +999,13 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { } } - fn recover_typed_object_ty(&self, name: &Name, type_args: &[TypeExpr]) -> Option { + fn recover_typed_object_ty( + &self, + name: &Name, + type_args: &[TypeExpr], + fields: &[(Name, baml_compiler2_ast::ExprId)], + env: &HashMap, + ) -> Option { let def = self.res_ctx.own_items.lookup_type(self.ns_context, name)?; match def { Definition::Class(_) => { @@ -996,10 +1024,44 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { ) }) .collect(); - Some(Ty::Class( - crate::ty::NominalTypeRef::new_with_type_args(qtn, lowered_type_args), + let class_ty = Ty::Class( + crate::ty::NominalTypeRef::new_with_type_args(qtn.clone(), lowered_type_args), TyAttr::default(), - )) + ); + if !type_args.is_empty() { + return Some(class_ty); + } + + let class_loc = self.res_ctx.lookup_class_loc(self.db, &qtn)?; + let item_tree = file_item_tree(self.db, class_loc.file(self.db)); + let class_data = &item_tree[class_loc.id(self.db)]; + if class_data.generic_params.is_empty() { + return Some(class_ty); + } + + let Ty::Class(class_name, _) = &class_ty else { + return Some(class_ty); + }; + let field_types = self.res_ctx.lookup_class_fields(self.db, class_name); + let field_pairs: Vec<(Ty, Ty)> = fields + .iter() + .filter_map(|(field_name, field_expr)| { + let declared_ty = field_types + .iter() + .find_map(|(name, ty)| (name == field_name).then(|| ty.clone()))?; + let actual_ty = self.recover_expr_ty(*field_expr, env)?.widen_fresh(); + Some((declared_ty, actual_ty)) + }) + .collect(); + + Some( + crate::generics::infer_nominal_type_args_from_fields( + &class_ty, + &class_data.generic_params, + &field_pairs, + ) + .unwrap_or(class_ty), + ) } Definition::Enum(_) => { let qtn = qualify_def(self.db, def, name); @@ -1026,7 +1088,13 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { } fn resolve_member_ty(&self, base_ty: &Ty, field: &Name) -> Option { - let resolved_base = crate::throws_semantics::resolve_alias_chain(base_ty, self.aliases); + let resolved_base = + match crate::throws_semantics::resolve_alias_chain(base_ty, self.aliases) { + Ty::Optional(inner, _) => { + crate::throws_semantics::resolve_alias_chain(inner.as_ref(), self.aliases) + } + other => other, + }; let Ty::Class(class_name, _) = &resolved_base else { return None; }; @@ -1042,6 +1110,70 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { Some(method.function.as_ty()) } + fn recover_lambda_ty( + &self, + func_def: &baml_compiler2_ast::FunctionDef, + env: &HashMap, + ) -> Ty { + let params: Vec<(Option, Ty)> = func_def + .params + .iter() + .map(|param| { + let ty = param + .type_expr + .as_ref() + .map(|te| self.lower_local_type_expr(&te.expr)) + .unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }); + (Some(param.name.clone()), ty) + }) + .collect(); + + let mut lambda_env = env.clone(); + for (name, ty) in ¶ms { + if let Some(name) = name { + lambda_env.insert(name.clone(), ty.clone()); + } + } + + let ret_ty = func_def + .return_type + .as_ref() + .map(|te| self.lower_local_type_expr(&te.expr)) + .or_else(|| match &func_def.body { + Some(baml_compiler2_ast::FunctionBodyDef::Expr(body, _)) => body + .root_expr + .and_then(|root_expr| self.recover_expr_ty(root_expr, &lambda_env)), + _ => None, + }) + .unwrap_or(Ty::Unknown { + attr: TyAttr::default(), + }); + + let throws_ty = if let Some(throws) = &func_def.throws { + self.lower_local_type_expr(&throws.expr) + } else if let Some(baml_compiler2_ast::FunctionBodyDef::Expr(body, _)) = &func_def.body { + crate::throws_semantics::concrete_throws_ty_from_facts(collect_direct_throws( + self.db, + self.res_ctx.own_items, + self.ns_context, + body, + )) + } else { + Ty::Never { + attr: TyAttr::default(), + } + }; + + Ty::Function { + params, + ret: Box::new(ret_ty), + throws: Box::new(throws_ty), + attr: TyAttr::default(), + } + } + fn named_callable_ty(&self, path: &[Name]) -> Option { let (_source, function) = self .res_ctx @@ -1135,14 +1267,10 @@ impl<'a, 'db> MemberFieldCallCollector<'a, 'db> { } fn collapse_types(types: Vec) -> Option { - let mut unique = BTreeSet::new(); - for ty in types { - unique.insert(ty); - } - match unique.len() { - 0 => None, - 1 => unique.into_iter().next(), - _ => Some(Ty::Union(unique.into_iter().collect(), TyAttr::default())), + if types.is_empty() { + None + } else { + Some(crate::generics::merge_collection_member_types(types)) } } diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info.rs b/baml_language/crates/baml_lsp2_actions/src/type_info.rs index a5f38e2164..87d0305d5f 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info.rs @@ -375,15 +375,53 @@ fn inferred_throws_for_function( let throw_sets = baml_compiler2_tir::throw_inference::function_throw_sets(db, pkg_id); let key = baml_compiler2_tir::throw_inference::callable_throw_key(db, func_loc); - let facts = throw_sets.transitive_for(&key)?; + let pre_inference_facts = throw_sets.transitive_for(&key); + let has_effect_vars = pre_inference_facts.is_some_and(|facts| { + facts + .iter() + .any(|fact| matches!(fact, baml_compiler2_tir::ty::Ty::TypeVar(_, _))) + }); + + if !has_effect_vars + && let Some(effective_throws) = effective_throws_for_function(db, func_loc, pkg_id) + { + return match effective_throws { + baml_compiler2_tir::ty::Ty::Never { .. } => None, + other => Some(utils::display_ty(&other)), + }; + } + + let facts = pre_inference_facts?; if facts.is_empty() { return None; } - let parts: Vec = facts.iter().map(utils::display_ty).collect(); Some(parts.join(" | ")) } +fn effective_throws_for_function( + db: &dyn Db, + func_loc: baml_compiler2_hir::loc::FunctionLoc<'_>, + pkg_id: baml_compiler2_hir::package::PackageId<'_>, +) -> Option { + let body = baml_compiler2_hir::body::function_body(db, func_loc); + let baml_compiler2_hir::body::FunctionBody::Expr(expr_body) = body.as_ref() else { + return None; + }; + + let item_tree = baml_compiler2_hir::file_item_tree(db, func_loc.file(db)); + let func_data = &item_tree[func_loc.id(db)]; + let index = baml_compiler2_hir::file_semantic_index(db, func_loc.file(db)); + let (scope_idx, _) = index.scopes.iter().enumerate().find(|(_, scope)| { + scope.kind == ScopeKind::Function + && scope.range == func_data.span + && scope.name.as_ref() == Some(&func_data.name) + })?; + let scope_id = index.scope_ids[scope_idx]; + let inference = baml_compiler2_tir::inference::infer_scope_types(db, scope_id); + Some(inference.effective_throws(db, pkg_id, expr_body)) +} + fn declared_throws_for_function( db: &dyn Db, func_loc: baml_compiler2_hir::loc::FunctionLoc<'_>, diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs index d6107bc432..8cca5948a6 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs @@ -509,4 +509,120 @@ function <[CURSOR]qualified() -> null throws root.errors.Io | string { "Hover should preserve the qualified declared throws clause, got: {md}" ); } + + #[test] + fn hover_function_with_inferred_throws_from_inline_heterogeneous_collection() { + let test = CursorTest::new( + r#" +class Task { + name: string + run: () -> null throws E +} + +function <[CURSOR]run_inline_tasks() -> null { + let tasks = [ + Task { name: "a", run: () -> null { throw "error" } }, + Task { name: "b", run: () -> null { throw 42 } }, + ] + for (let task in tasks) { + task.run() + } + null +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws.as_deref(), Some("int | string")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws int | string"), + "Hover should show inferred throws from inline stored callbacks, got: {md}" + ); + } + + #[test] + fn hover_local_var_preserves_generic_payload_for_inline_task_array() { + let test = CursorTest::new( + r#" +class Task { + name: string + run: () -> null throws E +} + +function run_inline_tasks() -> null { + let <[CURSOR]tasks = [ + Task { name: "a", run: () -> null { throw "error" } }, + Task { name: "b", run: () -> null { throw 42 } }, + ] + null +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::LocalVar { name, ty } => { + assert_eq!(name, "tasks"); + assert_eq!(ty, "user.Task[]"); + } + other => panic!("Expected TypeInfo::LocalVar, got: {other:?}"), + } + } + + #[test] + fn hover_function_with_caught_optional_call_omits_throws_clause() { + let test = CursorTest::new( + r#" +function <[CURSOR]optional_call_caught(cb: (() -> int throws string)?) -> int? { + cb?.() catch (e) { + _ => null + } +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws, &None); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + !md.contains(") -> int? throws"), + "Caught optional callback hover should omit throws, got: {md}" + ); + } + + #[test] + fn hover_function_with_optional_method_call_shows_inferred_throws() { + let test = CursorTest::new( + r#" +class Handler { + run: () -> null throws E +} + +function <[CURSOR]maybe_use_string_handler(h: Handler?) -> null { + h?.run() + null +} +"#, + ); + let info = test.type_info().expect("should resolve"); + match &info { + TypeInfo::Function { throws, .. } => { + assert_eq!(throws.as_deref(), Some("string")); + } + other => panic!("Expected TypeInfo::Function, got: {other:?}"), + } + let md = info.to_hover_markdown(); + assert!( + md.contains("throws string"), + "Optional method call hover should show inferred throws, got: {md}" + ); + } } diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____03_hir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____03_hir.snap index 2cef9bce8e..ccd911ba75 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____03_hir.snap @@ -21,10 +21,10 @@ function baml.length(self: ?) -> int [builtin] function baml.map(self: ?, f: (baml.T) -> baml.U) -> baml.U[] [builtin] function baml.map(self: ?, f: (baml.K, baml.V) -> baml.U) -> baml.U[] [builtin] function baml.map_keys(self: ?, f: (baml.K) -> baml.U) -> baml.U[] [expr] { - { } self.keys().map(f) + { let result: baml.U[] = []; for key in self.keys() { } result.push(f(key)) } result } function baml.map_values(self: ?, f: (baml.V) -> baml.U) -> baml.U[] [expr] { - { } self.values().map(f) + { let result: baml.U[] = []; for value in self.values() { } result.push(f(value)) } result } function baml.pop(self: ?) -> baml.T? [builtin] function baml.push(self: ?, item: baml.T) -> int [builtin] diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap index 7a3645e28e..357d3a10e4 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap @@ -29,19 +29,59 @@ fn baml.Map.map_keys(self: baml.Map, f: (void) -> void) -> void[] { let _0: void[] // _0 // return let _1: baml.Map // self // param let _2: (void) -> void // f // param - let _3: void[] + let _3: void[] // result + let _4: void[] + let _5: int // __for_idx + let _6: int + let _7: bool + let _8: void + let _9: void // key + let _10: void + let _11: void[] + let _12: void + let _13: void bb0: { - _3 = call const fn baml.Map.keys(copy _1) -> [bb1]; + _3 = []; + _4 = call const fn baml.Map.keys(copy _1) -> [bb1]; } bb1: { - _0 = call const fn baml.Array.map(copy _3, copy _2) -> [bb2]; + _5 = const 0_i64; + goto -> bb2; } bb2: { + _6 = len(_4); + _7 = copy _5 < copy _6; + branch copy _7 -> [bb5, bb3]; + } + + bb3: { + _0 = copy _3; + goto -> bb4; + } + + bb4: { return; } + + bb5: { + _8 = copy _4[_5]; + _9 = copy _8; + _11 = copy _3; + _13 = copy _9; + _12 = call copy _2(copy _13) -> [bb6]; + } + + bb6: { + _10 = call const fn baml.Array.push(copy _11, copy _12) -> [bb7]; + } + + bb7: { + _5 = copy _5 + const 1_i64; + goto -> bb2; + } } fn baml.Map.map_values(self: baml.Map, f: (void) -> void) -> void[] { @@ -49,19 +89,59 @@ fn baml.Map.map_values(self: baml.Map, f: (void) -> void) -> void[] { let _0: void[] // _0 // return let _1: baml.Map // self // param let _2: (void) -> void // f // param - let _3: void[] + let _3: void[] // result + let _4: void[] + let _5: int // __for_idx + let _6: int + let _7: bool + let _8: void + let _9: void // value + let _10: void + let _11: void[] + let _12: void + let _13: void bb0: { - _3 = call const fn baml.Map.values(copy _1) -> [bb1]; + _3 = []; + _4 = call const fn baml.Map.values(copy _1) -> [bb1]; } bb1: { - _0 = call const fn baml.Array.map(copy _3, copy _2) -> [bb2]; + _5 = const 0_i64; + goto -> bb2; } bb2: { + _6 = len(_4); + _7 = copy _5 < copy _6; + branch copy _7 -> [bb5, bb3]; + } + + bb3: { + _0 = copy _3; + goto -> bb4; + } + + bb4: { return; } + + bb5: { + _8 = copy _4[_5]; + _9 = copy _8; + _11 = copy _3; + _13 = copy _9; + _12 = call copy _2(copy _13) -> [bb6]; + } + + bb6: { + _10 = call const fn baml.Array.push(copy _11, copy _12) -> [bb7]; + } + + bb7: { + _5 = copy _5 + const 1_i64; + goto -> bb2; + } } fn baml.Array.pop = builtin(vm) diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap index 1943763966..dc967bcda1 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_tir.snap @@ -10,14 +10,24 @@ class baml.Array { } class baml.Map { } -function baml.Map.map_keys(self: baml.Map, f: (unknown) -> U) -> U[] throws never { +function baml.Map.map_keys(self: baml.Map, f: (unknown) -> U throws __throws_f) -> U[] throws __throws_f { { : U[] - self.keys().map(f) : U[] + let result = [] : U[] + for key in self.keys() + { : null + result.push(f(key)) : null + } + result : U[] } } -function baml.Map.map_values(self: baml.Map, f: (unknown) -> U) -> U[] throws never { +function baml.Map.map_values(self: baml.Map, f: (unknown) -> U throws __throws_f) -> U[] throws __throws_f { { : U[] - self.values().map(f) : U[] + let result = [] : U[] + for value in self.values() + { : null + result.push(f(value)) : null + } + result : U[] } } class baml.Array$stream { @@ -183,7 +193,7 @@ function baml.llm.render_prompt(client: baml.llm.Client, function_name: string, primitive_client.specialize_prompt(prompt) : baml.llm.PromptAst } } -function baml.llm.build_request(client: baml.llm.Client, function_name: string, args: map) -> baml.http.Request throws baml.errors.LlmClient | unknown { +function baml.llm.build_request(client: baml.llm.Client, function_name: string, args: map) -> baml.http.Request throws baml.errors.LlmClient | baml.errors.RenderPrompt | unknown { { : baml.http.Request let primitive_client = client.to_primitive_client() : baml.llm.PrimitiveClient let specialized_prompt = render_prompt(client, function_name, args) : baml.llm.PromptAst @@ -196,7 +206,7 @@ function baml.llm.parse(function_name: string, json: string) -> T throws baml __sap_parse(json, return_type) : T } } -function baml.llm.call_llm_function(client: baml.llm.Client, function_name: string, args: map) -> T throws baml.errors.InvalidArgument | unknown { +function baml.llm.call_llm_function(client: baml.llm.Client, function_name: string, args: map) -> T throws baml.errors.InvalidArgument | baml.errors.LlmClient | baml.errors.RenderPrompt | unknown { { : T let jinja_string = get_jinja_template(function_name) : string let context = ExecutionContext { jinja_string: jinja_string, args: args, function_name: function_name } : baml.llm.ExecutionContext @@ -336,7 +346,7 @@ function baml.llm.Client.build_plan_with_state(self: baml.llm.Client, planner_st self.build_attempt_with_state(planner_state) : baml.llm.OrchestrationStep[] } } -function baml.llm.Client.execute(self: baml.llm.Client, context: baml.llm.ExecutionContext, inherited_delay_ms: int) -> T throws unknown { +function baml.llm.Client.execute(self: baml.llm.Client, context: baml.llm.ExecutionContext, inherited_delay_ms: int) -> T throws baml.errors.LlmClient | baml.errors.RenderPrompt | unknown { { : never match (self.retry : baml.llm.RetryPolicy?) : never r: RetryPolicy => diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap index 1d5c46fac7..b396171774 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap @@ -44,19 +44,81 @@ function baml.Map.map(self: null, f: (void, void) -> void) -> void[] { } function baml.Map.map_keys(self: null, f: (void) -> void) -> void[] { + alloc_array 0 + store_var result load_var self call baml.Map.keys - load_var f - call baml.Array.map + store_var _4 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _4 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_var result return + + L2: + load_var _4 + load_var __for_idx + load_array_element + load_var f + call_indirect + store_var _12 + load_var result + load_var _12 + call baml.Array.push + pop 1 + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 } function baml.Map.map_values(self: null, f: (void) -> void) -> void[] { + alloc_array 0 + store_var result load_var self call baml.Map.values - load_var f - call baml.Array.map + store_var _4 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _4 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_var result return + + L2: + load_var _4 + load_var __for_idx + load_array_element + load_var f + call_indirect + store_var _12 + load_var result + load_var _12 + call baml.Array.push + pop 1 + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 } function baml.Map.set(self: null, key: void, value: void) -> null { diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap index 0ebb68fbfa..56ed5bf2eb 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_5_mir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 13184 --- === MIR2 === fn user.test_array_map_pure() -> int[] { @@ -1651,7 +1650,7 @@ fn user.maybe_use_string_handler(h: Handler?) -> null { // Locals: let _0: null // _0 // return let _1: Handler? // h // param - let _2: void + let _2: null let _3: (() -> null throws string)? let _4: bool @@ -1926,7 +1925,7 @@ fn user.run_inline_tasks() -> null { let _10: Task let _11: Task // task let _12: void - let _13: () -> null + let _13: () -> null throws string | int let _14: Task bb0: { diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap index 6e5358ae85..4bc6037cfe 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__04_tir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 12944 --- === TIR2 === function user.test_array_map_pure() -> int[] throws never { @@ -714,10 +713,9 @@ function user.use_local_string_handler() -> null throws string { } function user.maybe_use_string_handler(h: user.Handler?) -> null throws string { { : null - h?.run() : unknown + h?.run() : null null : null } - !! 560..570: `(() -> null throws string)?` is not a function — it cannot be called } function user.make_string_handler() -> user.Handler throws never { { : user.Handler @@ -753,7 +751,7 @@ function user.test_local_string_handler() -> null throws string { use_local_string_handler() : null } } -function user.test_optional_string_handler() -> null throws never { +function user.test_optional_string_handler() -> null throws string { { : null maybe_use_string_handler(make_string_handler()) : null } @@ -820,9 +818,9 @@ function user.run_local_tasks() -> null throws int | string { null : null } } -function user.run_inline_tasks() -> null throws never { +function user.run_inline_tasks() -> null throws int | string { { : null - let tasks = [Task { name: "a", run: () -> null { ... } }, Task { name: "b", run: () -> null { ... } }] : user.Task[] + let tasks = [Task { name: "a", run: () -> null { ... } }, Task { name: "b", run: () -> null { ... } }] : user.Task[] for task in tasks { : null task.run() : null @@ -849,7 +847,7 @@ function user.test_run_local_tasks() -> null throws int | string { run_local_tasks() : null } } -function user.test_run_inline_tasks() -> null throws never { +function user.test_run_inline_tasks() -> null throws int | string { { : null run_inline_tasks() : null } diff --git a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap index bb35df7c94..f022620bf2 100644 --- a/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/function_type_throws/baml_tests__function_type_throws__05_diagnostics.snap @@ -329,16 +329,6 @@ source: crates/baml_tests/src/generated_tests.rs │ Note: Error code: E0001 ───╯ - [type] Error: `(() -> null throws string)?` is not a function — it cannot be called - ╭─[ generic_stored_callback.baml:27:1 ] - │ - 27 │ h?.run() - │ ─────┬──── - │ ╰────── `(() -> null throws string)?` is not a function — it cannot be called - │ - │ Note: Error code: E0001 -────╯ - [type] Error: type mismatch: expected user.Box, got user.Box ╭─[ generic_subtype_rejection.baml:14:19 ] │ diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/phase6.rs b/baml_language/crates/baml_tests/src/compiler2_tir/phase6.rs index 16ae04c3b8..e0523ad132 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/phase6.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/phase6.rs @@ -257,6 +257,34 @@ fn let_inferred_from_map_keys() { "); } +#[test] +fn map_keys_rethrows_callback_throws() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#"function f(m: map) -> int[] { return m.map_keys((k: string) -> int { throw "boom" }); }"#, + ); + let tir = render_tir(&db, file); + assert!( + tir.contains("function user.f(m: map) -> int[] throws string"), + "expected map_keys to propagate lambda throws, got:\n{tir}" + ); +} + +#[test] +fn map_values_rethrows_callback_throws() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#"function f(m: map) -> string[] { return m.map_values((v: int) -> string { throw "boom" }); }"#, + ); + let tir = render_tir(&db, file); + assert!( + tir.contains("function user.f(m: map) -> string[] throws string"), + "expected map_values to propagate lambda throws, got:\n{tir}" + ); +} + // ── Media type method resolution ────────────────────────────────────────────── #[test]