From 5dfa1b8d472a0b81ae9fc2337f4f4d3a9c3176f4 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 27 May 2026 16:32:05 +0800 Subject: [PATCH 01/48] Add `PrimitiveType`, `ErasedType`, and refactor `Value.Lit` --- .../src/main/scala/hkmc2/codegen/Block.scala | 81 +++++++++++++++++-- .../scala/hkmc2/codegen/BlockSimplifier.scala | 38 ++++----- .../hkmc2/codegen/BlockTransformer.scala | 2 +- .../scala/hkmc2/codegen/BlockTraverser.scala | 2 +- .../hkmc2/codegen/BufferableTransform.scala | 6 +- .../scala/hkmc2/codegen/DeadParamElim.scala | 2 +- .../scala/hkmc2/codegen/HandlerLowering.scala | 34 ++++---- .../src/main/scala/hkmc2/codegen/Lifter.scala | 2 +- .../main/scala/hkmc2/codegen/Lowering.scala | 38 ++++----- .../main/scala/hkmc2/codegen/Printer.scala | 2 +- .../codegen/ReflectionInstrumenter.scala | 6 +- .../hkmc2/codegen/SpecializedSwitch.scala | 2 +- .../hkmc2/codegen/StackSafeTransform.scala | 2 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 12 +-- .../hkmc2/codegen/deforest/Rewrite.scala | 2 +- .../codegen/flowAnalysis/FlowAnalysis.scala | 6 +- .../scala/hkmc2/codegen/js/JSBuilder.scala | 16 ++-- .../scala/hkmc2/codegen/llir/Builder.scala | 10 +-- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 14 ++-- .../hkmc2/semantics/ucs/Normalization.scala | 2 +- .../src/test/mlscript/codegen/BasicTerms.mls | 12 ++- .../test/mlscript/codegen/BlockPrinter.mls | 8 +- 22 files changed, 187 insertions(+), 112 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 43b7cc95e4..2565dc9245 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -12,7 +12,7 @@ import hkmc2.semantics.{Term => st} import syntax.{Literal, Tree, SpreadKind, Keyword} import semantics.* import semantics.Term.* -import sem.Elaborator.State +import sem.Elaborator.{Ctx, State, ctx} /* Important design notes. @@ -846,6 +846,51 @@ enum Case: case Tup(_, _) => Set.empty case Field(_, _) => Set.empty +/** A primitive type of the block IR. */ +enum PrimitiveType: + case Unit, Int, Int31, Num, Str, Bool + + /** The symbol for this primitive type, if available. */ + def sym(using Ctx, State): Opt[ClassLikeSymbol] = this match + case Unit => S(summon[State].unitSymbol) + case Int => S(ctx.builtins.Int) + case Int31 => S(ctx.builtins.Int31) + case Num => S(ctx.builtins.Num) + case Str => S(ctx.builtins.Str) + case Bool => S(ctx.builtins.Bool) + + +object ErasedType: + def objectRef(using Ctx): ErasedType = ErasedType.AnyRef(rsc = false, ctx.builtins.Object) + +/** A generics-erased type of the Block IR. */ +enum ErasedType: + /** + * An reference to a class-like symbol. + * + * - `rsc` is true if this reference is a resource class. + */ + case AnyRef(rsc: Bool, csym: ClassLikeSymbol) + + /** An primitive type. */ + case Primitive(prim: PrimitiveType) + +/** Trait representing a Block IR element that has an [[`ErasedType`]]. */ +trait HasErasedType: + /** The [[`ErasedType`]] of this element, or `N` if the erased type is not known. */ + def erasedType: Opt[ErasedType] + + /** Similar to `erasedType`, but coerces to the top type if the specific erased type is not known. */ + def erasedType_!(using Ctx): ErasedType = erasedType.getOrElse(ErasedType.objectRef) + +extension (lit: Literal) + def erasedType: ErasedType = lit match + case Tree.UnitLit(_) => ErasedType.Primitive(PrimitiveType.Unit) + case Tree.IntLit(_) => ErasedType.Primitive(PrimitiveType.Int) + case Tree.DecLit(_) => ErasedType.Primitive(PrimitiveType.Num) + case Tree.StrLit(_) => ErasedType.Primitive(PrimitiveType.Str) + case Tree.BoolLit(_) => ErasedType.Primitive(PrimitiveType.Bool) + sealed trait TrivialResult extends Result sealed abstract class Result extends AutoLocated: @@ -857,7 +902,7 @@ sealed abstract class Result extends AutoLocated: case Value.SimpleRef(l) => l.showAsPlain case Value.MemberRef(l, disamb) => s"${l.showAsPlain}${s"‹${disamb.showAsPlain}›"}" case Value.This(sym) => s"this[${sym.showAsPlain}]" - case Value.Lit(lit) => lit.idStr + case Value.Lit(lit, _) => lit.idStr case Select(q, n) => s"Select(${q.showDbg}, ${n.showDbg})" case DynSelect(q, fld, arrayIdx) => s"DynSelect(${q.showDbg}, ${fld.showDbg}, $arrayIdx)" case Call(fun, argss) => s"Call(${fun.showDbg}, [${ @@ -896,7 +941,7 @@ sealed abstract class Result extends AutoLocated: case Value.SimpleRef(l) => Vector.empty case Value.MemberRef(bms, disamb) => Vector.empty case Value.This(sym) => Vector.empty - case Value.Lit(lit) => Vector.single(lit) + case Value.Lit(lit, _) => Vector.single(lit) // TODO rm Lam from values and thus the need for this method def subBlocks: Ls[Block] = this match @@ -918,7 +963,7 @@ sealed abstract class Result extends AutoLocated: case Value.SimpleRef(l) => Set(l) case Value.MemberRef(bms, _) => Set(bms) case Value.This(sym) => Set.empty - case Value.Lit(lit) => Set.empty + case Value.Lit(lit, _) => Set.empty case DynSelect(qual, fld, arrayIdx) => qual.freeVars ++ fld.freeVars lazy val freeVarsLLIR: Set[Local] = this match @@ -941,7 +986,7 @@ sealed abstract class Result extends AutoLocated: case Some(d: TermDefinition) if d.companionClass.isDefined => Set.empty case _ => Set(l) case Value.This(sym) => Set.empty - case Value.Lit(lit) => Set.empty + case Value.Lit(lit, _) => Set.empty case DynSelect(qual, fld, arrayIdx) => qual.freeVarsLLIR ++ fld.freeVarsLLIR lazy val size: Int = this match @@ -952,8 +997,8 @@ sealed abstract class Result extends AutoLocated: case Tuple(mut, elems) => elems.iterator.map(_.value.size).sum case Record(mut, args) => args.iterator.map(arg => arg.idx.fold(0)(_.size) + arg.value.size).sum case _: Value.RefLike => 0 - case Value.Lit(l: Tree.StrLit) => l.value.length / 4 - case Value.Lit(lit) => 0 + case Value.Lit(l: Tree.StrLit, _) => l.value.length / 4 + case Value.Lit(lit, _) => 0 case DynSelect(qual, fld, arrayIdx) => qual.size + fld.size // * TODO: refine this very loose type @@ -1009,7 +1054,7 @@ enum Value extends Path with ProductWithExtraInfo: */ case MemberRef(bms: BlockMemberSymbol, disamb: DefinitionSymbol[?]) case This(sym: InnerSymbol) - case Lit(lit: Literal) + case Lit(lit: Literal, erasedType: ErasedType) override def extraInfo(using DebugPrinter): Str = this match case MemberRef(bms, disamb) => s"disamb=${disamb.showAsPlain}" @@ -1028,6 +1073,26 @@ object Value: case MemberRef(bms, _) => bms case This(sym) => sym + object IntLit: + def apply(i: BigInt): Value.Lit = + Value.Lit(Tree.IntLit(i), ErasedType.Primitive(PrimitiveType.Int)) + + object DecLit: + def apply(d: BigDecimal): Value.Lit = + Value.Lit(Tree.DecLit(d), ErasedType.Primitive(PrimitiveType.Num)) + + object StrLit: + def apply(s: Str): Value.Lit = + Value.Lit(Tree.StrLit(s), ErasedType.Primitive(PrimitiveType.Str)) + + object UnitLit: + def apply(isNullNotUndefined: Bool): Value.Lit = + Value.Lit(Tree.UnitLit(isNullNotUndefined), ErasedType.Primitive(PrimitiveType.Unit)) + + object BoolLit: + def apply(b: Bool): Value.Lit = + Value.Lit(Tree.BoolLit(b), ErasedType.Primitive(PrimitiveType.Bool)) + @deprecated("Use Value.SimpleRef, Value.MemberRef, or Value.This instead.") object Ref: def apply(l: Local, disamb: Opt[DefinitionSymbol[?]]): Value.RefLike = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala index a4b991ad13..d052230733 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala @@ -219,7 +219,7 @@ class BlockSimplifier case Value.SimpleRef(loc) if localVars.contains(loc) && !definedVars.contains(loc) => registerChange(s"${loc.showDbg} is never assigned; replacing read with undefined") // if !symbolsToPreserve(loc) then removedLocals += loc - k(Value.Lit(syntax.Tree.UnitLit(false))) + k(Value.UnitLit(false)) case _ => super.applyValue(v)(k) override def applyBlock(b: Block): Block = b match @@ -649,7 +649,7 @@ class BlockSimplifier assignedResults.get(r).fold(giveUp)(getShapesA) case Value.MemberRef(r, sym: ModuleOrObjectSymbol) => Set.single(sym) - case Value.Lit(lit) => Set.single(lit) + case Value.Lit(lit, _) => Set.single(lit) case _ => giveUp var shapes = if deadBranchRemoval then getShapes(scrut2) else giveUp @@ -760,7 +760,7 @@ class BlockSimplifier if litValue =/= false then ass.rhs match - case v @ Value.Lit(lit) => + case v @ Value.Lit(lit, _) => if litValue === true then litValue = v else if litValue =/= v then @@ -793,7 +793,7 @@ class BlockSimplifier litValue match case true => registerChange(s"${loc.showDbg} ~> undefined") - return k(Value.Lit(syntax.Tree.UnitLit(false))) + return k(Value.UnitLit(false)) case lit: Value => registerChange(s"${loc.showDbg} ~> ${lit.showDbg}") return k(lit) @@ -874,22 +874,22 @@ class BlockSimplifier // TODO: mv to smart ctor of Call import syntax.Tree.*, Value.Lit val builtinEval: PartialFunction[(Str, List[Value]), Value] = - case ("+", (lit @ Lit(IntLit(v1))) :: Nil) => lit - case ("+", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(IntLit(v1 + v2)) - case ("-", Lit(IntLit(v1)) :: Nil) => Lit(IntLit(-v1)) - case ("-", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(IntLit(v1 - v2)) - case ("*", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(IntLit(v1 * v2)) + case ("+", (lit @ Lit(IntLit(v1), _)) :: Nil) => lit + case ("+", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.IntLit(v1 + v2) + case ("-", Lit(IntLit(v1), _) :: Nil) => Value.IntLit(-v1) + case ("-", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.IntLit(v1 - v2) + case ("*", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.IntLit(v1 * v2) // * For "/", should check for 0 and return a DecLit - case ("%", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(IntLit(v1 % v2)) - case ("===", Lit(l1) :: Lit(l2) :: Nil) => Lit(BoolLit(l1 == l2)) - case ("!==", Lit(l1) :: Lit(l2) :: Nil) => Lit(BoolLit(l1 != l2)) - case ("<", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(BoolLit(v1 < v2)) - case ("<=", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(BoolLit(v1 <= v2)) - case (">", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(BoolLit(v1 > v2)) - case (">=", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(BoolLit(v1 >= v2)) - case ("&&", Lit(BoolLit(v1)) :: Lit(BoolLit(v2)) :: Nil) => Lit(BoolLit(v1 && v2)) - case ("||", Lit(BoolLit(v1)) :: Lit(BoolLit(v2)) :: Nil) => Lit(BoolLit(v1 || v2)) - case ("!", Lit(BoolLit(v)) :: Nil) => Lit(BoolLit(!v)) + case ("%", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.IntLit(v1 % v2) + case ("===", Lit(l1, _) :: Lit(l2, _) :: Nil) => Value.BoolLit(l1 == l2) + case ("!==", Lit(l1, _) :: Lit(l2, _) :: Nil) => Value.BoolLit(l1 != l2) + case ("<", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.BoolLit(v1 < v2) + case ("<=", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.BoolLit(v1 <= v2) + case (">", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.BoolLit(v1 > v2) + case (">=", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.BoolLit(v1 >= v2) + case ("&&", Lit(BoolLit(v1), _) :: Lit(BoolLit(v2), _) :: Nil) => Value.BoolLit(v1 && v2) + case ("||", Lit(BoolLit(v1), _) :: Lit(BoolLit(v2), _) :: Nil) => Value.BoolLit(v1 || v2) + case ("!", Lit(BoolLit(v), _) :: Nil) => Value.BoolLit(!v) end DataFlowAnalysis diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala index c30ebbddfc..9668580742 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala @@ -191,7 +191,7 @@ class BlockTransformer(subst: SymbolSubst): case Value.This(sym) => val sym2 = sym.subst k(if (sym2 is sym) then v else sym2.asThis.withLocOf(v)) - case Value.Lit(lit) => k(v) + case Value.Lit(lit, _) => k(v) def applyLocal(sym: Local): Local = sym.subst diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala index 24f44aa1ee..5f66bb41fb 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala @@ -76,7 +76,7 @@ class BlockTraverser: bms.traverse disamb.traverse case Value.This(sym) => sym.traverse - case Value.Lit(lit) => () + case Value.Lit(lit, _) => () def applyLocal(sym: Local): Unit = sym.traverse diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala index 0ab44b19db..946b16af21 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala @@ -38,11 +38,11 @@ class BufferableTransform()(using Ctx, State, Raise): def mkFieldReplacer(buf: VarSymbol, baseIdx: VarSymbol, symMap: Map[Symbol, Symbol]) = def getOffset(off: Int)(k: Path => Block): Block = val idxSymbol = new TempSymbol(N, "idx") - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.IntLit(off).asArg :: Nil) ne_:: Nil)(true, false, false), k(DynSelect(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true)))) def assignToOffset(off: Int, r: Result, rst: Block) = val idxSymbol = new TempSymbol(N, "idx") - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.IntLit(off).asArg :: Nil) ne_:: Nil)(true, false, false), AssignDynField(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true, r, applyBlock(rst)))) new BlockTransformer(SymbolSubst.Id): override def applyLocal(sym: Local): Local = symMap.getOrElse(sym, sym) @@ -91,7 +91,7 @@ class BufferableTransform()(using Ctx, State, Raise): fakeCtor :: cls.methods.map(transformFunDefn(_, false)), Nil, clsSizeSym -> clsSizeTermSym :: Nil, - Define(ValDefn(clsSizeTermSym, clsSizeSym, Value.Lit(Tree.IntLit(fields.size)))(N, Nil), End()), + Define(ValDefn(clsSizeTermSym, clsSizeSym, Value.IntLit(fields.size))(N, Nil), End()), annotations = Nil, ) k: diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala index 2b75982a87..cd6960f8db 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala @@ -238,7 +238,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise): override def applyValue(v: Value)(k: Value => Block): Block = v match case ref@Value.SimpleRef(l: VarSymbol) if activeEliminatedParams(l) => - k(Value.Lit(Tree.UnitLit(false)).withLocOf(ref)) + k(Value.UnitLit(false).withLocOf(ref)) case _ => super.applyValue(v)(k) override def applyBlock(b: Block): Block = b match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 11263f63b2..0b34e1adc2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -24,12 +24,12 @@ object HandlerLowering: private val nextIdent: Tree.Ident = Tree.Ident("next") private val lastIdent: Tree.Ident = Tree.Ident("last") private val contTraceIdent: Tree.Ident = Tree.Ident("contTrace") - private def unit = Value.Lit(Tree.UnitLit(true)) - private def intLit(i: BigInt) = Value.Lit(Tree.IntLit(i)) + private def unit = Value.UnitLit(true) + private def intLit(i: BigInt) = Value.IntLit(i) private def locToStr(loc: Loc) = val (line, _, col) = loc.origin.fph.getLineColAt(loc.spanStart) - Value.Lit(Tree.StrLit(s"${loc.origin.fileName.last}:${line + loc.origin.startLineNum - 1}:$col")) + Value.StrLit(s"${loc.origin.fileName.last}:${line + loc.origin.startLineNum - 1}:$col") extension (p: Path) def pc = p.selN(pcIdent) @@ -121,7 +121,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, private def rtThrowMsg(msg: Str) = Throw( Instantiate(mut = false, State.globalThisSymbol.asThis.selN(Tree.Ident("Error")), - (Value.Lit(Tree.StrLit(msg)).asArg :: Nil) :: Nil) + (Value.StrLit(msg).asArg :: Nil) :: Nil) ) object PureCall: @@ -138,18 +138,18 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, object StateTransition: private val transitionSymbol = freshTmp("transition") def apply(uid: StateId) = - Return(PureCall(transitionSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid))))) + Return(PureCall(transitionSymbol.asSimpleRef, List(Value.IntLit(uid)))) def unapply(blk: Block) = blk match - case Return(PureCall(Value.SimpleRef(`transitionSymbol`), List(Value.Lit(Tree.IntLit(uid))))) => + case Return(PureCall(Value.SimpleRef(`transitionSymbol`), List(Value.Lit(Tree.IntLit(uid), _)))) => S(uid) case _ => N object Unwind: private val unwindSymbol = freshTmp("unwind") def apply(uid: StateId, loc: Value) = - Return(PureCall(unwindSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid)), loc))) + Return(PureCall(unwindSymbol.asSimpleRef, List(Value.IntLit(uid), loc))) def unapply(blk: Block) = blk match - case Return(PureCall(Value.SimpleRef(`unwindSymbol`), List(Value.Lit(Tree.IntLit(uid)), loc: Value))) => + case Return(PureCall(Value.SimpleRef(`unwindSymbol`), List(Value.Lit(Tree.IntLit(uid), _), loc: Value))) => S(uid, loc) case _ => N @@ -533,9 +533,9 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case Scoped(syms, body) => syms case _ => Set() val varList = scopedVars.toList.sortBy(_.uid) - val debugInfo = Value.Lit(Tree.StrLit(debugNme)).asArg :: varList.zipWithIndex.filter(_._1.isInstanceOf[VarSymbol]) + val debugInfo = Value.StrLit(debugNme).asArg :: varList.zipWithIndex.filter(_._1.isInstanceOf[VarSymbol]) .flatMap: (sym, idx) => - List(intLit(idx), Value.Lit(Tree.StrLit(sym.nme))) + List(intLit(idx), Value.StrLit(sym.nme)) .map(_.asArg) val debugInfoSym = freshTmp(s"$debugNme$$debugInfo") // TODO: properly support spread argument by calculating the correct length. @@ -595,12 +595,12 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case _ => super.applyDefn(defn)(k) val b = preTransform.applyBlock(blk) if h.inCtor then - return translateIllegalEffectCtx(b, Call(paths.illegalEffectPath, (Value.Lit(Tree.StrLit("in a constructor")).asArg :: Nil) ne_:: Nil)(true, true, false)) + return translateIllegalEffectCtx(b, Call(paths.illegalEffectPath, (Value.StrLit("in a constructor").asArg :: Nil) ne_:: Nil)(true, true, false)) if h.inTopLevel then - return translateIllegalEffectCtx(b, Call(paths.topLevelEffectPath, (Value.Lit(Tree.BoolLit(opt.debug)).asArg :: Nil) ne_:: Nil)(true, false, false)) + return translateIllegalEffectCtx(b, Call(paths.topLevelEffectPath, (Value.BoolLit(opt.debug).asArg :: Nil) ne_:: Nil)(true, false, false)) val ctx = h.asInstanceOf[HandlerCtx.FunctionLike].ctx if ctx.inGetter then - return translateIllegalEffectCtx(b, Call(paths.illegalEffectPath, (Value.Lit(Tree.StrLit("in a getter")).asArg :: Nil) ne_:: Nil)(true, false, false)) + return translateIllegalEffectCtx(b, Call(paths.illegalEffectPath, (Value.StrLit("in a getter").asArg :: Nil) ne_:: Nil)(true, false, false)) given FunctionCtx = ctx val parts = partitionBlock(b) stackSafetyMap += ctx.resumeInfo.currentStackSafetySym -> @@ -621,7 +621,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val segmentTailTransform = new BlockTransformerShallow(SymbolSubst.Id): override def applyBlock(b: Block) = b match case StateTransition(uid) => - Assign(pcVar, Value.Lit(Tree.IntLit(uid)), Continue(mainLoopLbl)) + Assign(pcVar, Value.IntLit(uid), Continue(mainLoopLbl)) case Unwind(uid, loc) => ctx.doUnwind(loc, uid, vars)(using paths) case _ => super.applyBlock(b) @@ -641,7 +641,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case StateTransition(uid) => assert(uid === nextState) if isSimple then - Assign(pcVar, Value.Lit(Tree.IntLit(uid)), End()) + Assign(pcVar, Value.IntLit(uid), End()) else Break(lblSym) case Unwind(uid, loc) => @@ -651,7 +651,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, if isSimple then transformed else Label( lblSym, false, transformed, - Assign(pcVar, Value.Lit(Tree.IntLit(nextState)), End()) + Assign(pcVar, Value.IntLit(nextState), End()) ) line match case head :: next => @@ -711,7 +711,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, Case.Lit(Tree.IntLit(-1)) -> Assign(pcVar, intLit(parts.entry), End()) :: Nil, S(restoreVars - .assignFieldN(paths.runtimePath, new Tree.Ident("resumePc"), Value.Lit(Tree.IntLit(-1))).end), + .assignFieldN(paths.runtimePath, new Tree.Ident("resumePc"), Value.IntLit(-1)).end), mainLoop)) private def translateCtorLike(b: Block, thisPath: Path, isModCtor: Bool)(using h: HandlerCtx): Block = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 063d784c4d..6b38de5680 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -648,7 +648,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): protected final def addExtraSyms(b: Block, captureSym: Local, objSyms: Iterable[Local], define: Bool): Block = if hasCapture then - val undef = Value.Lit(Tree.UnitLit(false)).asArg + val undef = Value.UnitLit(false).asArg val inst = Instantiate( true, captureClass.sym.asMemberRef(captureClass.isym), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index c10f08baa8..3ad4c1429d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -29,7 +29,7 @@ object Ret extends TailOp: object ImplctRet extends TailOp: def apply(r: Result): Block = r match - case Value.Lit(Tree.UnitLit(false)) => End() + case Value.Lit(Tree.UnitLit(false), _) => End() case _ => Return(r) object Thrw extends TailOp: def apply(r: Result): Block = Throw(r) @@ -75,7 +75,7 @@ import LoweringCtx.loweringCtx object Lowering: def compError: Block = - Throw(Value.Lit(Tree.StrLit("This code cannot be run as its compilation yielded an error."))) + Throw(Value.StrLit("This code cannot be run as its compilation yielded an error.")) def fail(err: ErrorReport)(using Raise): Block = raise(err) @@ -636,7 +636,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case st.UnitVal() => k(unit) case st.Lit(lit) => if lit =/= Tree.UnitLit(false) then warnStmt - k(Value.Lit(lit)) + k(Value.Lit(lit, lit.erasedType)) case st.Ret(res) => returnedTerm(res) case st.Throw(res) => @@ -679,7 +679,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case st.Break(label, result, value) => value match case S(v) => term(v)(r => Assign(result, r, Break(label))) - case N => Assign(result, Value.Lit(Tree.UnitLit(false)), Break(label)) + case N => Assign(result, Value.UnitLit(false), Break(label)) case st.Continue(label) => Continue(label) case st.Asc(lhs, rhs) => @@ -724,14 +724,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if k.isInstanceOf[TailOp] then Match( ar1, (Case.Lit(posLit) -> term_nonTail(arg2)(k)) :: Nil, - S(k(Value.Lit(negLit))), + S(k(Value.Lit(negLit, negLit.erasedType))), Unreachable("tail operation in branches"), ) else val ts = loweringCtx.registerTempSymbol(N) Match( ar1, (Case.Lit(posLit) -> term_nonTail(arg2)(Assign(ts, _, End()))) :: Nil, - S(Assign(ts, Value.Lit(negLit), End())), + S(Assign(ts, Value.Lit(negLit, negLit.erasedType), End())), k(ts.asSimpleRef), ) sym match @@ -1010,10 +1010,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): def setupSymbol(symbol: Local)(k: Result => Block)(using LoweringCtx): Block = k(Instantiate(mut = false, State.termSymbol.asSimpleRef.selSN("Symbol"), - (Value.Lit(Tree.StrLit(symbol.nme)).asArg :: Nil) :: Nil)) + (Value.StrLit(symbol.nme).asArg :: Nil) :: Nil)) def quotePattern(p: FlatPattern)(k: Result => Block)(using LoweringCtx): Block = p match - case FlatPattern.Lit(lit) => setupTerm("LitPattern", Value.Lit(lit) :: Nil)(k) + case FlatPattern.Lit(lit) => setupTerm("LitPattern", Value.Lit(lit, lit.erasedType) :: Nil)(k) case _ => // TODO fail: ErrorReport( @@ -1050,10 +1050,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): def quote(t: st)(k: Result => Block)(using LoweringCtx): Block = t match case Lit(lit) => - setupTerm("Lit", Value.Lit(lit) :: Nil)(k) + setupTerm("Lit", Value.Lit(lit, lit.erasedType) :: Nil)(k) case Ref(sym) if Elaborator.binaryOps.contains(sym.nme) => // builtin symbols val l = loweringCtx.registerTempSymbol(N) - setupTerm("Builtin", Value.Lit(Tree.StrLit(sym.nme)) :: Nil)(k) + setupTerm("Builtin", Value.StrLit(sym.nme) :: Nil)(k) case Resolved(Ref(sym), disamb) => sym match case sym: BlockMemberSymbol => k(sym.asMemberRef(disamb)) @@ -1063,8 +1063,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case SynthSel(Ref(sym: ModuleOrObjectSymbol), name) => // Local cross-stage references setupSymbol(sym): r1 => val l1, l2 = loweringCtx.registerTempSymbol(N) - Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.UnitLit(false)) :: Nil)(r2 => - Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) + Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.UnitLit(false) :: Nil)(r2 => + Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.StrLit(name.name) :: Nil)(k)) )) case SynthSel(Ref(sym: BlockMemberSymbol), name) => // Multi-file cross-stage references if config.qqEnabled then fail: @@ -1079,8 +1079,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val basePath = base.up val targetPath = filename val relPath = targetPath.relativeTo(basePath).map(_.toString).getOrElse(targetPath.toString) - Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.StrLit(relPath)) :: Nil)(r2 => - Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) + Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.StrLit(relPath) :: Nil)(r2 => + Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.StrLit(name.name) :: Nil)(k)) )) case _ => fail: ErrorReport( @@ -1399,7 +1399,7 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State) .ifthen(selRes.asSimpleRef, Case.Lit(syntax.Tree.UnitLit(false)), Throw(Instantiate(mut = false, Select(State.globalThisSymbol.asThis, Tree.Ident("Error"))(N), - (Value.Lit(syntax.Tree.StrLit(s"Access to required field '${nme.name}' yielded 'undefined'")).asArg :: Nil) :: Nil)) + (Value.StrLit(s"Access to required field '${nme.name}' yielded 'undefined'").asArg :: Nil) :: Nil)) ) .rest(k(selRes.asSimpleRef)) @@ -1455,10 +1455,10 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) val resInspectedSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogResInspected") - val psSymArgs = psInspectedSyms.zipWithIndex.foldRight[Ls[Arg]](Arg(N, Value.Lit(Tree.StrLit(")"))) :: Nil): + val psSymArgs = psInspectedSyms.zipWithIndex.foldRight[Ls[Arg]](Arg(N, Value.StrLit(")")) :: Nil): case (((s, p), i), acc) => if i == psInspectedSyms.length - 1 then Arg(N, s.asSimpleRef) :: acc - else Arg(N, s.asSimpleRef) :: Arg(N, Value.Lit(Tree.StrLit(", "))) :: acc + else Arg(N, s.asSimpleRef) :: Arg(N, Value.StrLit(", ")) :: acc val tmp1, tmp2, tmp3 = loweringCtx.registerTempSymbol(N) @@ -1468,7 +1468,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) assignStmts( enterMsgSym -> pureCall( strConcatFn, - Arg(N, Value.Lit(Tree.StrLit(s"CALL ${name.getOrElse("[arrow function]")}("))) :: psSymArgs + Arg(N, Value.StrLit(s"CALL ${name.getOrElse("[arrow function]")}(")) :: psSymArgs ), tmp1 -> pureCall(traceLogFn, Arg(N, enterMsgSym.asSimpleRef) :: Nil), prevIndentLvlSym -> pureCall(traceLogIndentFn, Nil) @@ -1479,7 +1479,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) resInspectedSym -> pureCall(inspectFn, Arg(N, resSym.asSimpleRef) :: Nil), retMsgSym -> pureCall( strConcatFn, - Arg(N, Value.Lit(Tree.StrLit("=> "))) :: Arg(N, resInspectedSym.asSimpleRef) :: Nil + Arg(N, Value.StrLit("=> ")) :: Arg(N, resInspectedSym.asSimpleRef) :: Nil ), tmp2 -> pureCall(traceLogResetFn, Arg(N, prevIndentLvlSym.asSimpleRef) :: Nil), tmp3 -> pureCall(traceLogFn, Arg(N, retMsgSym.asSimpleRef) :: Nil) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index 7b80670ce8..b5dfc905f8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -168,7 +168,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case Value.MemberRef(bms, disamb) => showSymbol(bms.nme, S(disamb)) case Value.This(sym) if sym === State.globalThisSymbol => showSymbol(sym.nme, S(sym.asDefnSym)) case Value.This(sym) => doc"${print(sym)}.this" - case Value.Lit(lit) => doc"${lit.idStr}" + case Value.Lit(lit, _) => doc"${lit.idStr}" def print(path: Path)(using Scope): Document = path match case sel @ Select(qual, name) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala index 62c13f56ea..dec706e902 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala @@ -38,7 +38,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n case b: Bool => Tree.BoolLit(b) case s: Str => Tree.StrLit(s) case n: BigDecimal => Tree.DecLit(n) - Value.Lit(l) + Value.Lit(l, l.erasedType) extension [A, B](ls: Ls[(A => B) => B]) def collectApply(f: Ls[A] => B): B = @@ -213,7 +213,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n raise(ErrorReport(msg"Instantiate with multiple argument lists not supported in staged module." -> r.toLoc :: Nil)) End() // desugar Runtime.Tuple.get into Select - case Call(fun, Ls(Arg(_, scrut), Arg(_, Value.Lit(Tree.IntLit(idx)))) :: _) if fun == Value.SimpleRef(State.runtimeSymbol).selSN("Tuple").selSN("get") => + case Call(fun, Ls(Arg(_, scrut), Arg(_, Value.Lit(Tree.IntLit(idx), _))) :: _) if fun == Value.SimpleRef(State.runtimeSymbol).selSN("Tuple").selSN("get") => transformPath(Select(scrut, Tree.Ident(idx.toString()))(N))(k) case Call(fun, argss) => val stagedFunPath = fun match @@ -261,7 +261,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n transformOption(pOpt, transformParamList)(k) def transformCase(cse: Case)(using Context)(k: Path => Block): Block = cse match - case Case.Lit(lit) => blockCtor("Lit", Ls(Value.Lit(lit)))(k) + case Case.Lit(lit) => blockCtor("Lit", Ls(Value.Lit(lit, lit.erasedType)))(k) case Case.Cls(cls, path) => transformSymbol(cls): cls => transformPath(path): path => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala index 55b88d3c89..5077bab0df 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala @@ -118,7 +118,7 @@ private object PostCondAnalysisImpl extends CachedAnalysis[Block, PostCondRes]: private def res(lhs: Opt[Local], rhs: Result, rest: Block) = if rhs.isPure then lhs match case Some(lhs) => rhs match - case Value.Lit(lit) => PostCondRes(false, false, Map(lhs -> lit)) >=> analyze(rest) + case Value.Lit(lit, _) => PostCondRes(false, false, Map(lhs -> lit)) >=> analyze(rest) case _ => analyze(rest) case None => analyze(rest) else analyze(rest).markImpure diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala index bb6ab15766..03f7b33aef 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -17,7 +17,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S private val runStackSafePath: Path = runtimePath.selN(Tree.Ident("runStackSafe")) private val stackDepthPath: Path = runtimePath.selN(STACK_DEPTH_IDENT) - private def intLit(n: BigInt) = Value.Lit(Tree.IntLit(n)) + private def intLit(n: BigInt) = Value.IntLit(n) private def op(op: String, a: Path, b: Path) = Call(State.builtinOpsMap(op).asSimpleRef, (a.asArg :: b.asArg :: Nil) ne_:: Nil)(true, false, false) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 3bcf8a89ba..0c28b53563 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -350,7 +350,7 @@ class TailRecOpt(using State, TL, Raise): // The code used to continute the loop. val cont = if scc.funs.size === 1 then Continue(loopSym) - else Assign(curIdSym, Value.Lit(Tree.IntLit(dSymIds(calleeSym))), Continue(loopSym)) + else Assign(curIdSym, Value.IntLit(dSymIds(calleeSym)), Continue(loopSym)) // In some cases, we could have assignments like this: // param0 = whatever // param1 = @@ -433,7 +433,7 @@ class TailRecOpt(using State, TL, Raise): // Main args def mainArgs(rest: List[Path]) = (0 until paramList.size).toList.foldRight(rest): - case (n, acc) => DynSelect(tupleSym.asSimpleRef, Value.Lit(Tree.IntLit(n)), true) :: acc + case (n, acc) => DynSelect(tupleSym.asSimpleRef, Value.IntLit(n), true) :: acc // If the rest param exists, append a slice val (initialBlk: (Block => Block), pathList: List[Path]) = @@ -445,8 +445,8 @@ class TailRecOpt(using State, TL, Raise): .sel(Tree.Ident("Tuple"), State.tupleSymbol) .sel(Tree.Ident("slice"), State.tupleSliceSymbol), (tupleSym.asSimpleRef.asArg - :: Value.Lit(Tree.IntLit(paramList.length)).asArg - :: Value.Lit(Tree.IntLit(0)).asArg + :: Value.IntLit(paramList.length).asArg + :: Value.IntLit(0).asArg :: Nil) ne_:: Nil )(true, false, false) val blk = blockBuilder @@ -495,9 +495,9 @@ class TailRecOpt(using State, TL, Raise): else funs.map: f => val paramArgs = getParamSyms(f).map(s => s.asSimpleRef.asArg) val args = - Value.Lit(Tree.IntLit(dSymIds(f.dSym))).asArg + Value.IntLit(dSymIds(f.dSym)).asArg :: paramArgs - ::: List.fill(maxParamLen - paramArgs.length)(Value.Lit(Tree.UnitLit(false)).asArg) + ::: List.fill(maxParamLen - paramArgs.length)(Value.UnitLit(false).asArg) val newBod = Return( Call(sel, args ne_:: Nil)(true, false, false), ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala index 95a8d49cd6..5eb54d3817 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala @@ -553,7 +553,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): transformedOgBody, mkReturnCall(parentFunSym, parentFunFvs)) case None => - Begin(transformedOgBody, Return(Value.Lit(Tree.UnitLit(true)))) + Begin(transformedOgBody, Return(Value.UnitLit(true))) val refreshedFvSymbols = restFnFvs(restFunId).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"))) val bodyWithCorrectSymbols = new RefreshSymbol(refreshedFvSymbols.toMap).applyBlock(actualBody) FunDefn(tsym.owner, bms, tsym, refreshedFvSymbols.unzip._2.asParamList :: Nil, bodyWithCorrectSymbols)(N, annotations = PrivateModifier :: Nil) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala index f1316f6d72..6d0937b141 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala @@ -117,7 +117,7 @@ object PossibleTrackableTupleSelect: s match case Call( Select(Select(Value.SimpleRef(runtimeSym), Tree.Ident("Tuple")), Tree.Ident("get")), - (Arg(N, ref@Value.SimpleRef(scrut)) :: Arg(N, Value.Lit(Tree.IntLit(n))) :: Nil) :: Nil + (Arg(N, ref@Value.SimpleRef(scrut)) :: Arg(N, Value.Lit(Tree.IntLit(n), _)) :: Nil) :: Nil ) if runtimeSym is eState.runtimeSymbol => S(ref -> n.toInt) case _ => N @@ -598,7 +598,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using case v@Value.SimpleRef(l) => applyValueSimpleRef(v, recordAffinity = true) case v@Value.MemberRef(_, _) => applyValueMemberRef(v, recordAffinity = true) case Value.This(sym) => () - case Value.Lit(lit) => () + case Value.Lit(lit, _) => () override def applyFunDefn(fun: FunDefn): Unit = ctxTracker.inFun(fun): @@ -1030,7 +1030,7 @@ class FlowConstraintsCollector( cc.constrain(processResult(qual), UnknownCons) cc.constrain(processResult(fld), UnknownCons) UnknownProd - case Value.Lit(lit) => UnknownProd + case Value.Lit(lit, _) => UnknownProd } end FlowConstraintsCollector diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index 7c654446b7..42f854315b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -57,7 +57,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case Operand(prec: Int) def mkErr(errMsg: Message)(using Raise, Scope): Document = - doc"throw globalThis.Error(${result(Value.Lit(syntax.Tree.StrLit(errMsg.show)))})" + doc"throw globalThis.Error(${result(Value.StrLit(errMsg.show))})" def errExpr(errMsg: Message)(using Raise, Scope): Document = raise(ErrorReport(errMsg -> N :: Nil, @@ -174,8 +174,8 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: // * Module self-references use the module name itself instead of `this` scope.lookup_!(ts, r.toLoc) case Value.This(sym) => scope.findThis_!(sym) - case Value.Lit(Tree.StrLit(value)) => makeStringLiteral(value) - case Value.Lit(lit) => lit.idStr + case Value.Lit(Tree.StrLit(value), _) => makeStringLiteral(value) + case Value.Lit(lit, _) => lit.idStr case Value.MemberRef(bms, disamb) => if disamb.shouldBeLifted then doc"${scope.lookup_!(bms, bms.toLoc)}.class" else scope.lookup_!(bms, r.toLoc) @@ -257,9 +257,9 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case Record(mut, flds) => val inner = bracketed(pre = "{", post = "}", insertBreak = true): flds.map: - case RcdArg(S(Value.Lit(IntLit(idx))), v) => + case RcdArg(S(Value.Lit(IntLit(idx), _)), v) => doc"${idx.toString}: ${result(v)}" - case RcdArg(S(Value.Lit(StrLit(idx))), v) => + case RcdArg(S(Value.Lit(StrLit(idx), _)), v) => doc"${if isValidIdentifier(idx) then idx else s"\"$idx\""}: ${result(v)}" case RcdArg(S(idx), v) => doc"[${result(idx)}]: ${result(v)}" @@ -326,7 +326,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: => lastBlkAssign(b) match // the one branch ends by assigning `nextInt` to `scrutSym` - case S(Assign(`scrutSym_`, Value.Lit(Tree.IntLit(nextInt)), _)) => + case S(Assign(`scrutSym_`, Value.Lit(Tree.IntLit(nextInt), _), _)) => unapplyImpl(rest, (curVal_, b) :: acc, S(scrut_), S(nextInt)) case _ => S((scrut_, (curVal_, b) :: acc, rest)) @@ -650,7 +650,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: doc" # $resJS" - case Return(Value.Lit(UnitLit(false))) => doc" # return${mkSemi}" + case Return(Value.Lit(UnitLit(false), _)) => doc" # return${mkSemi}" case Return(res) => doc" # return ${result(res)}${mkSemi}" case Match(scrut, Nil, els, rest) => @@ -665,7 +665,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case SpecializedSwitch(scrut, cases, dflt, rest) => val switchBod = cases.foldLeft(doc""): (acc, arm) => val needsBreak = arm.isInstanceOf[SwitchCase.ExplicitBreak] - acc :: doc" # case ${result(Value.Lit(arm.litValue))}: #{ ${ + acc :: doc" # case ${result(Value.Lit(arm.litValue, arm.litValue.erasedType))}: #{ ${ // * Note: we use `block` here so that Scoped nodes will create proper brace sections, // * necessary since `case` clauses do not create a new scope, // * so something like `switch (x) { case 1: let y = 1; break; case 2: let y = 2 }` is ill-formed! diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 70bb4401d2..5f722d5d04 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -312,7 +312,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): ctx.fn_ctx.get(sym) match case None => k(ctx.findName(sym) |> sr) case Some(_) => bErrStop(msg"Unsupported value: This with function context") - case Value.Lit(lit) => k(Expr.Literal(lit)) + case Value.Lit(lit, _) => k(Expr.Literal(lit)) private def getClassOfField(p: DefinitionSymbol[?])(using ctx: Ctx)(using Raise, Scope): Local = @@ -344,7 +344,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): trace[Node](s"bPath { $p } begin", x => s"bPath end: ${x.show}"): p match case s @ Select(Value.MemberRef(sym, _), Tree.Ident("Unit")) if sym is ctx.builtinSym.runtimeSym.get => - bPath(Value.Lit(Tree.UnitLit(false)))(k) + bPath(Value.UnitLit(false))(k) case s @ DynSelect(qual, fld, arrayIdx) => bErrStop(msg"Unsupported dynamic selection") case s @ Select(qual, name) => @@ -504,15 +504,15 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): Node.Case(e, casesList, defaultCase) case Return(res) => bResult(res)(x => Node.Result(Ls(x))) case Throw(Instantiate(false, Select(Value.SimpleRef(_), ident), - Ls(Arg(N, Value.Lit(Tree.StrLit(e)))) :: Nil)) + Ls(Arg(N, Value.Lit(Tree.StrLit(e), _))) :: Nil)) if ident.name === "Error" => Node.Panic(e) case Throw(Instantiate(false, Select(Value.MemberRef(_, _), ident), - Ls(Arg(N, Value.Lit(Tree.StrLit(e)))) :: Nil)) + Ls(Arg(N, Value.Lit(Tree.StrLit(e), _))) :: Nil)) if ident.name === "Error" => Node.Panic(e) case Throw(Instantiate(false, Select(Value.This(_), ident), - Ls(Arg(N, Value.Lit(Tree.StrLit(e)))) :: Nil)) + Ls(Arg(N, Value.Lit(Tree.StrLit(e), _))) :: Nil)) if ident.name === "Error" => Node.Panic(e) case Label(label, loop, body, rest) => TODO("Label not supported") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 085dbad291..2d0fb2b5cd 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -1023,7 +1023,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: errExtra: => Str, )(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr => Expr = fld match - case Value.Lit(IntLit(value)) if value.isValidInt => + case Value.Lit(IntLit(value), _) if value.isValidInt => val idx = value.toInt tupleRef => if idx >= 0 then i32.const(idx) @@ -1184,11 +1184,11 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ) def result(r: codegen.Result)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = r match - case Value.Lit(BoolLit(value)) => + case Value.Lit(BoolLit(value), _) => ref.i31(i32.const(if value then 1 else 0)) - case Value.Lit(IntLit(value)) => + case Value.Lit(IntLit(value), _) => withValidIntLit(value, r.toLoc)(intVal => ref.i31(i32.const(intVal))) - case Value.Lit(StrLit(value)) => + case Value.Lit(StrLit(value), _) => val lit = internStringLiteral(value) val stringCtor = getOrLoadStrCtorFunction call( @@ -1421,10 +1421,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case Select(Value.This(sym), id) if (sym eq State.globalThisSymbol) && id.name == "Error" => return as.headOption match case S(arg) => arg.value match - case Value.Lit(BoolLit(value)) => ref.i31(i32.const(if value then 1 else 0)) - case Value.Lit(IntLit(value)) => + case Value.Lit(BoolLit(value), _) => ref.i31(i32.const(if value then 1 else 0)) + case Value.Lit(IntLit(value), _) => withValidIntLit(value, arg.value.toLoc)(intVal => ref.i31(i32.const(intVal))) - case Value.Lit(StrLit(_)) => result(arg.value) + case Value.Lit(StrLit(_), _) => result(arg.value) case unsupported => warnExpr( msg"WatBuilder::result for Instantiate(...) of `globalThis.Error(...)` with payload `${ diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 1a23ab6f32..94c68a4941 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -311,7 +311,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e */ private def throwMatchErrorBlock = Throw(Instantiate(mut = false, Select(State.globalThisSymbol.asThis, Tree.Ident("Error"))(S(ctx.builtins.Error)), - (Value.Lit(syntax.Tree.StrLit("match error")).asArg :: Nil) :: Nil)) // TODO add failed-match scrutinee info + (Value.StrLit("match error").asArg :: Nil) :: Nil)) // TODO add failed-match scrutinee info import syntax.Keyword.{`if`, `while`} diff --git a/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls b/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls index 97a3436100..031cf9ef58 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls @@ -24,7 +24,9 @@ //│ ╙── ^ //│ —————————————| Lowered IR Tree |———————————————————————————————————————————————————————————————————— //│ Program: -//│ main = Return of Lit of IntLit of 2 +//│ main = Return of Lit: +//│ lit = IntLit of 2 +//│ erasedType = Primitive of Int //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— //│ = 2 @@ -53,8 +55,12 @@ print("Hi") //│ argss = Ls of //│ Ls of //│ Arg: -//│ value = Lit of StrLit of "Hi" -//│ rest = Return of Lit of IntLit of 2 +//│ value = Lit: +//│ lit = StrLit of "Hi" +//│ erasedType = Primitive of Str +//│ rest = Return of Lit: +//│ lit = IntLit of 2 +//│ erasedType = Primitive of Int //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— //│ > Hi //│ = 2 diff --git a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls index 4c62749de8..039cf81ead 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls @@ -27,7 +27,9 @@ x + 1 //│ main = Scoped(syms = {x⁰}): //│ body = Assign: //│ lhs = x⁰ -//│ rhs = Lit of IntLit of 1 +//│ rhs = Lit: +//│ lit = IntLit of 1 +//│ erasedType = Primitive of Int //│ rest = Return of Call: //│ fun = SimpleRef of builtin:+⁰ //│ argss = Ls of @@ -35,7 +37,9 @@ x + 1 //│ Arg: //│ value = SimpleRef of x⁰ //│ Arg: -//│ value = Lit of IntLit of 1 +//│ value = Lit: +//│ lit = IntLit of 1 +//│ erasedType = Primitive of Int //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let x⁰; set x⁰ = 1; return +⁰(x⁰, 1) //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— From e3509d7249b336d94915b587f760fedfb3c6d045 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 28 May 2026 14:37:48 +0800 Subject: [PATCH 02/48] codegen: Add ErasedType for Value.This --- hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 2565dc9245..aa5069d0ec 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -1055,6 +1055,12 @@ enum Value extends Path with ProductWithExtraInfo: case MemberRef(bms: BlockMemberSymbol, disamb: DefinitionSymbol[?]) case This(sym: InnerSymbol) case Lit(lit: Literal, erasedType: ErasedType) + + /** Returns the [[`ErasedType`]] of this value. */ + def erasedType(using Ctx): ErasedType = this match + case This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => ErasedType.AnyRef(false, clsOrMod) + case Lit(lit, et) => et + case _ => ErasedType.objectRef override def extraInfo(using DebugPrinter): Str = this match case MemberRef(bms, disamb) => s"disamb=${disamb.showAsPlain}" From f0fe27c24d82fcc4d4572f29a436c5d18482a8c1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 28 May 2026 14:51:19 +0800 Subject: [PATCH 03/48] codegen: Remove `Value.Lit.erasedType` --- .../src/main/scala/hkmc2/codegen/Block.scala | 36 ++++-------------- .../scala/hkmc2/codegen/BlockSimplifier.scala | 38 +++++++++---------- .../hkmc2/codegen/BlockTransformer.scala | 2 +- .../scala/hkmc2/codegen/BlockTraverser.scala | 2 +- .../hkmc2/codegen/BufferableTransform.scala | 6 +-- .../scala/hkmc2/codegen/DeadParamElim.scala | 2 +- .../scala/hkmc2/codegen/HandlerLowering.scala | 34 ++++++++--------- .../src/main/scala/hkmc2/codegen/Lifter.scala | 2 +- .../main/scala/hkmc2/codegen/Lowering.scala | 38 +++++++++---------- .../main/scala/hkmc2/codegen/Printer.scala | 2 +- .../codegen/ReflectionInstrumenter.scala | 6 +-- .../hkmc2/codegen/SpecializedSwitch.scala | 2 +- .../hkmc2/codegen/StackSafeTransform.scala | 2 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 12 +++--- .../hkmc2/codegen/deforest/Rewrite.scala | 2 +- .../codegen/flowAnalysis/FlowAnalysis.scala | 6 +-- .../scala/hkmc2/codegen/js/JSBuilder.scala | 16 ++++---- .../scala/hkmc2/codegen/llir/Builder.scala | 10 ++--- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 14 +++---- .../hkmc2/semantics/ucs/Normalization.scala | 2 +- .../src/test/mlscript/codegen/BasicTerms.mls | 12 ++---- .../test/mlscript/codegen/BlockPrinter.mls | 8 +--- 22 files changed, 112 insertions(+), 142 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index aa5069d0ec..4ff2d41b82 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -902,7 +902,7 @@ sealed abstract class Result extends AutoLocated: case Value.SimpleRef(l) => l.showAsPlain case Value.MemberRef(l, disamb) => s"${l.showAsPlain}${s"‹${disamb.showAsPlain}›"}" case Value.This(sym) => s"this[${sym.showAsPlain}]" - case Value.Lit(lit, _) => lit.idStr + case Value.Lit(lit) => lit.idStr case Select(q, n) => s"Select(${q.showDbg}, ${n.showDbg})" case DynSelect(q, fld, arrayIdx) => s"DynSelect(${q.showDbg}, ${fld.showDbg}, $arrayIdx)" case Call(fun, argss) => s"Call(${fun.showDbg}, [${ @@ -941,7 +941,7 @@ sealed abstract class Result extends AutoLocated: case Value.SimpleRef(l) => Vector.empty case Value.MemberRef(bms, disamb) => Vector.empty case Value.This(sym) => Vector.empty - case Value.Lit(lit, _) => Vector.single(lit) + case Value.Lit(lit) => Vector.single(lit) // TODO rm Lam from values and thus the need for this method def subBlocks: Ls[Block] = this match @@ -963,7 +963,7 @@ sealed abstract class Result extends AutoLocated: case Value.SimpleRef(l) => Set(l) case Value.MemberRef(bms, _) => Set(bms) case Value.This(sym) => Set.empty - case Value.Lit(lit, _) => Set.empty + case Value.Lit(lit) => Set.empty case DynSelect(qual, fld, arrayIdx) => qual.freeVars ++ fld.freeVars lazy val freeVarsLLIR: Set[Local] = this match @@ -986,7 +986,7 @@ sealed abstract class Result extends AutoLocated: case Some(d: TermDefinition) if d.companionClass.isDefined => Set.empty case _ => Set(l) case Value.This(sym) => Set.empty - case Value.Lit(lit, _) => Set.empty + case Value.Lit(lit) => Set.empty case DynSelect(qual, fld, arrayIdx) => qual.freeVarsLLIR ++ fld.freeVarsLLIR lazy val size: Int = this match @@ -997,8 +997,8 @@ sealed abstract class Result extends AutoLocated: case Tuple(mut, elems) => elems.iterator.map(_.value.size).sum case Record(mut, args) => args.iterator.map(arg => arg.idx.fold(0)(_.size) + arg.value.size).sum case _: Value.RefLike => 0 - case Value.Lit(l: Tree.StrLit, _) => l.value.length / 4 - case Value.Lit(lit, _) => 0 + case Value.Lit(l: Tree.StrLit) => l.value.length / 4 + case Value.Lit(lit) => 0 case DynSelect(qual, fld, arrayIdx) => qual.size + fld.size // * TODO: refine this very loose type @@ -1054,12 +1054,12 @@ enum Value extends Path with ProductWithExtraInfo: */ case MemberRef(bms: BlockMemberSymbol, disamb: DefinitionSymbol[?]) case This(sym: InnerSymbol) - case Lit(lit: Literal, erasedType: ErasedType) + case Lit(lit: Literal) /** Returns the [[`ErasedType`]] of this value. */ def erasedType(using Ctx): ErasedType = this match case This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => ErasedType.AnyRef(false, clsOrMod) - case Lit(lit, et) => et + case Lit(lit) => lit.erasedType case _ => ErasedType.objectRef override def extraInfo(using DebugPrinter): Str = this match @@ -1078,27 +1078,7 @@ object Value: case SimpleRef(l) => l case MemberRef(bms, _) => bms case This(sym) => sym - - object IntLit: - def apply(i: BigInt): Value.Lit = - Value.Lit(Tree.IntLit(i), ErasedType.Primitive(PrimitiveType.Int)) - - object DecLit: - def apply(d: BigDecimal): Value.Lit = - Value.Lit(Tree.DecLit(d), ErasedType.Primitive(PrimitiveType.Num)) - - object StrLit: - def apply(s: Str): Value.Lit = - Value.Lit(Tree.StrLit(s), ErasedType.Primitive(PrimitiveType.Str)) - object UnitLit: - def apply(isNullNotUndefined: Bool): Value.Lit = - Value.Lit(Tree.UnitLit(isNullNotUndefined), ErasedType.Primitive(PrimitiveType.Unit)) - - object BoolLit: - def apply(b: Bool): Value.Lit = - Value.Lit(Tree.BoolLit(b), ErasedType.Primitive(PrimitiveType.Bool)) - @deprecated("Use Value.SimpleRef, Value.MemberRef, or Value.This instead.") object Ref: def apply(l: Local, disamb: Opt[DefinitionSymbol[?]]): Value.RefLike = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala index 04de8d07d6..3292609efd 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala @@ -219,7 +219,7 @@ class BlockSimplifier case Value.SimpleRef(loc) if localVars.contains(loc) && !definedVars.contains(loc) => registerChange(s"${loc.showDbg} is never assigned; replacing read with undefined") // if !symbolsToPreserve(loc) then removedLocals += loc - k(Value.UnitLit(false)) + k(Value.Lit(syntax.Tree.UnitLit(false))) case _ => super.applyValue(v)(k) override def applyBlock(b: Block): Block = b match @@ -646,7 +646,7 @@ class BlockSimplifier assignedResults.get(r).fold(giveUp)(getShapesA) case Value.MemberRef(r, sym: ModuleOrObjectSymbol) => Set.single(sym) - case Value.Lit(lit, _) => Set.single(lit) + case Value.Lit(lit) => Set.single(lit) case _ => giveUp var shapes = if deadBranchRemoval then getShapes(scrut2) else giveUp @@ -757,7 +757,7 @@ class BlockSimplifier if litValue =/= false then ass.rhs match - case v @ Value.Lit(lit, _) => + case v @ Value.Lit(lit) => if litValue === true then litValue = v else if litValue =/= v then @@ -790,7 +790,7 @@ class BlockSimplifier litValue match case true => registerChange(s"${loc.showDbg} ~> undefined") - return k(Value.UnitLit(false)) + return k(Value.Lit(syntax.Tree.UnitLit(false))) case lit: Value => registerChange(s"${loc.showDbg} ~> ${lit.showDbg}") return k(lit) @@ -871,22 +871,22 @@ class BlockSimplifier // TODO: mv to smart ctor of Call import syntax.Tree.*, Value.Lit val builtinEval: PartialFunction[(Str, List[Value]), Value] = - case ("+", (lit @ Lit(IntLit(v1), _)) :: Nil) => lit - case ("+", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.IntLit(v1 + v2) - case ("-", Lit(IntLit(v1), _) :: Nil) => Value.IntLit(-v1) - case ("-", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.IntLit(v1 - v2) - case ("*", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.IntLit(v1 * v2) + case ("+", (lit @ Lit(IntLit(v1))) :: Nil) => lit + case ("+", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(IntLit(v1 + v2)) + case ("-", Lit(IntLit(v1)) :: Nil) => Lit(IntLit(-v1)) + case ("-", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(IntLit(v1 - v2)) + case ("*", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(IntLit(v1 * v2)) // * For "/", should check for 0 and return a DecLit - case ("%", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.IntLit(v1 % v2) - case ("===", Lit(l1, _) :: Lit(l2, _) :: Nil) => Value.BoolLit(l1 == l2) - case ("!==", Lit(l1, _) :: Lit(l2, _) :: Nil) => Value.BoolLit(l1 != l2) - case ("<", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.BoolLit(v1 < v2) - case ("<=", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.BoolLit(v1 <= v2) - case (">", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.BoolLit(v1 > v2) - case (">=", Lit(IntLit(v1), _) :: Lit(IntLit(v2), _) :: Nil) => Value.BoolLit(v1 >= v2) - case ("&&", Lit(BoolLit(v1), _) :: Lit(BoolLit(v2), _) :: Nil) => Value.BoolLit(v1 && v2) - case ("||", Lit(BoolLit(v1), _) :: Lit(BoolLit(v2), _) :: Nil) => Value.BoolLit(v1 || v2) - case ("!", Lit(BoolLit(v), _) :: Nil) => Value.BoolLit(!v) + case ("%", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(IntLit(v1 % v2)) + case ("===", Lit(l1) :: Lit(l2) :: Nil) => Lit(BoolLit(l1 == l2)) + case ("!==", Lit(l1) :: Lit(l2) :: Nil) => Lit(BoolLit(l1 != l2)) + case ("<", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(BoolLit(v1 < v2)) + case ("<=", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(BoolLit(v1 <= v2)) + case (">", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(BoolLit(v1 > v2)) + case (">=", Lit(IntLit(v1)) :: Lit(IntLit(v2)) :: Nil) => Lit(BoolLit(v1 >= v2)) + case ("&&", Lit(BoolLit(v1)) :: Lit(BoolLit(v2)) :: Nil) => Lit(BoolLit(v1 && v2)) + case ("||", Lit(BoolLit(v1)) :: Lit(BoolLit(v2)) :: Nil) => Lit(BoolLit(v1 || v2)) + case ("!", Lit(BoolLit(v)) :: Nil) => Lit(BoolLit(!v)) end DataFlowAnalysis diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala index 9668580742..c30ebbddfc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala @@ -191,7 +191,7 @@ class BlockTransformer(subst: SymbolSubst): case Value.This(sym) => val sym2 = sym.subst k(if (sym2 is sym) then v else sym2.asThis.withLocOf(v)) - case Value.Lit(lit, _) => k(v) + case Value.Lit(lit) => k(v) def applyLocal(sym: Local): Local = sym.subst diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala index 5f66bb41fb..24f44aa1ee 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala @@ -76,7 +76,7 @@ class BlockTraverser: bms.traverse disamb.traverse case Value.This(sym) => sym.traverse - case Value.Lit(lit, _) => () + case Value.Lit(lit) => () def applyLocal(sym: Local): Unit = sym.traverse diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala index 1988ada2fc..bd1f989f1b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala @@ -38,11 +38,11 @@ class BufferableTransform()(using Ctx, State, Raise): def mkFieldReplacer(buf: VarSymbol, baseIdx: VarSymbol, symMap: Map[Symbol, Symbol]) = def getOffset(off: Int)(k: Path => Block): Block = val idxSymbol = new TempSymbol(N, "idx") - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.IntLit(off).asArg :: Nil) ne_:: Nil)(true, false, false), + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), k(DynSelect(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true)))) def assignToOffset(off: Int, r: Result, rst: Block) = val idxSymbol = new TempSymbol(N, "idx") - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.IntLit(off).asArg :: Nil) ne_:: Nil)(true, false, false), + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), AssignDynField(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true, r, applyBlock(rst)))) new BlockTransformer(SymbolSubst.Id): override def applyLocal(sym: Local): Local = symMap.getOrElse(sym, sym) @@ -91,7 +91,7 @@ class BufferableTransform()(using Ctx, State, Raise): fakeCtor :: cls.methods.map(transformFunDefn(_, false)), Nil, clsSizeSym -> clsSizeTermSym :: Nil, - Define(ValDefn(clsSizeTermSym, clsSizeSym, Value.IntLit(fields.size))(N, Nil), End()), + Define(ValDefn(clsSizeTermSym, clsSizeSym, Value.Lit(Tree.IntLit(fields.size)))(N, Nil), End()), annotations = Nil, ) k: diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala index cd6960f8db..2b75982a87 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala @@ -238,7 +238,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise): override def applyValue(v: Value)(k: Value => Block): Block = v match case ref@Value.SimpleRef(l: VarSymbol) if activeEliminatedParams(l) => - k(Value.UnitLit(false).withLocOf(ref)) + k(Value.Lit(Tree.UnitLit(false)).withLocOf(ref)) case _ => super.applyValue(v)(k) override def applyBlock(b: Block): Block = b match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 0b34e1adc2..11263f63b2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -24,12 +24,12 @@ object HandlerLowering: private val nextIdent: Tree.Ident = Tree.Ident("next") private val lastIdent: Tree.Ident = Tree.Ident("last") private val contTraceIdent: Tree.Ident = Tree.Ident("contTrace") - private def unit = Value.UnitLit(true) - private def intLit(i: BigInt) = Value.IntLit(i) + private def unit = Value.Lit(Tree.UnitLit(true)) + private def intLit(i: BigInt) = Value.Lit(Tree.IntLit(i)) private def locToStr(loc: Loc) = val (line, _, col) = loc.origin.fph.getLineColAt(loc.spanStart) - Value.StrLit(s"${loc.origin.fileName.last}:${line + loc.origin.startLineNum - 1}:$col") + Value.Lit(Tree.StrLit(s"${loc.origin.fileName.last}:${line + loc.origin.startLineNum - 1}:$col")) extension (p: Path) def pc = p.selN(pcIdent) @@ -121,7 +121,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, private def rtThrowMsg(msg: Str) = Throw( Instantiate(mut = false, State.globalThisSymbol.asThis.selN(Tree.Ident("Error")), - (Value.StrLit(msg).asArg :: Nil) :: Nil) + (Value.Lit(Tree.StrLit(msg)).asArg :: Nil) :: Nil) ) object PureCall: @@ -138,18 +138,18 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, object StateTransition: private val transitionSymbol = freshTmp("transition") def apply(uid: StateId) = - Return(PureCall(transitionSymbol.asSimpleRef, List(Value.IntLit(uid)))) + Return(PureCall(transitionSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid))))) def unapply(blk: Block) = blk match - case Return(PureCall(Value.SimpleRef(`transitionSymbol`), List(Value.Lit(Tree.IntLit(uid), _)))) => + case Return(PureCall(Value.SimpleRef(`transitionSymbol`), List(Value.Lit(Tree.IntLit(uid))))) => S(uid) case _ => N object Unwind: private val unwindSymbol = freshTmp("unwind") def apply(uid: StateId, loc: Value) = - Return(PureCall(unwindSymbol.asSimpleRef, List(Value.IntLit(uid), loc))) + Return(PureCall(unwindSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid)), loc))) def unapply(blk: Block) = blk match - case Return(PureCall(Value.SimpleRef(`unwindSymbol`), List(Value.Lit(Tree.IntLit(uid), _), loc: Value))) => + case Return(PureCall(Value.SimpleRef(`unwindSymbol`), List(Value.Lit(Tree.IntLit(uid)), loc: Value))) => S(uid, loc) case _ => N @@ -533,9 +533,9 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case Scoped(syms, body) => syms case _ => Set() val varList = scopedVars.toList.sortBy(_.uid) - val debugInfo = Value.StrLit(debugNme).asArg :: varList.zipWithIndex.filter(_._1.isInstanceOf[VarSymbol]) + val debugInfo = Value.Lit(Tree.StrLit(debugNme)).asArg :: varList.zipWithIndex.filter(_._1.isInstanceOf[VarSymbol]) .flatMap: (sym, idx) => - List(intLit(idx), Value.StrLit(sym.nme)) + List(intLit(idx), Value.Lit(Tree.StrLit(sym.nme))) .map(_.asArg) val debugInfoSym = freshTmp(s"$debugNme$$debugInfo") // TODO: properly support spread argument by calculating the correct length. @@ -595,12 +595,12 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case _ => super.applyDefn(defn)(k) val b = preTransform.applyBlock(blk) if h.inCtor then - return translateIllegalEffectCtx(b, Call(paths.illegalEffectPath, (Value.StrLit("in a constructor").asArg :: Nil) ne_:: Nil)(true, true, false)) + return translateIllegalEffectCtx(b, Call(paths.illegalEffectPath, (Value.Lit(Tree.StrLit("in a constructor")).asArg :: Nil) ne_:: Nil)(true, true, false)) if h.inTopLevel then - return translateIllegalEffectCtx(b, Call(paths.topLevelEffectPath, (Value.BoolLit(opt.debug).asArg :: Nil) ne_:: Nil)(true, false, false)) + return translateIllegalEffectCtx(b, Call(paths.topLevelEffectPath, (Value.Lit(Tree.BoolLit(opt.debug)).asArg :: Nil) ne_:: Nil)(true, false, false)) val ctx = h.asInstanceOf[HandlerCtx.FunctionLike].ctx if ctx.inGetter then - return translateIllegalEffectCtx(b, Call(paths.illegalEffectPath, (Value.StrLit("in a getter").asArg :: Nil) ne_:: Nil)(true, false, false)) + return translateIllegalEffectCtx(b, Call(paths.illegalEffectPath, (Value.Lit(Tree.StrLit("in a getter")).asArg :: Nil) ne_:: Nil)(true, false, false)) given FunctionCtx = ctx val parts = partitionBlock(b) stackSafetyMap += ctx.resumeInfo.currentStackSafetySym -> @@ -621,7 +621,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val segmentTailTransform = new BlockTransformerShallow(SymbolSubst.Id): override def applyBlock(b: Block) = b match case StateTransition(uid) => - Assign(pcVar, Value.IntLit(uid), Continue(mainLoopLbl)) + Assign(pcVar, Value.Lit(Tree.IntLit(uid)), Continue(mainLoopLbl)) case Unwind(uid, loc) => ctx.doUnwind(loc, uid, vars)(using paths) case _ => super.applyBlock(b) @@ -641,7 +641,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case StateTransition(uid) => assert(uid === nextState) if isSimple then - Assign(pcVar, Value.IntLit(uid), End()) + Assign(pcVar, Value.Lit(Tree.IntLit(uid)), End()) else Break(lblSym) case Unwind(uid, loc) => @@ -651,7 +651,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, if isSimple then transformed else Label( lblSym, false, transformed, - Assign(pcVar, Value.IntLit(nextState), End()) + Assign(pcVar, Value.Lit(Tree.IntLit(nextState)), End()) ) line match case head :: next => @@ -711,7 +711,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, Case.Lit(Tree.IntLit(-1)) -> Assign(pcVar, intLit(parts.entry), End()) :: Nil, S(restoreVars - .assignFieldN(paths.runtimePath, new Tree.Ident("resumePc"), Value.IntLit(-1)).end), + .assignFieldN(paths.runtimePath, new Tree.Ident("resumePc"), Value.Lit(Tree.IntLit(-1))).end), mainLoop)) private def translateCtorLike(b: Block, thisPath: Path, isModCtor: Bool)(using h: HandlerCtx): Block = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 72853eac5f..56a802ebc3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -648,7 +648,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): protected final def addExtraSyms(b: Block, captureSym: Local, objSyms: Iterable[Local], define: Bool): Block = if hasCapture then - val undef = Value.UnitLit(false).asArg + val undef = Value.Lit(Tree.UnitLit(false)).asArg val inst = Instantiate( true, captureClass.sym.asMemberRef(captureClass.isym), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index b47be7202e..4cd5db7f7b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -29,7 +29,7 @@ object Ret extends TailOp: object ImplctRet extends TailOp: def apply(r: Result): Block = r match - case Value.Lit(Tree.UnitLit(false), _) => End() + case Value.Lit(Tree.UnitLit(false)) => End() case _ => Return(r) object Thrw extends TailOp: def apply(r: Result): Block = Throw(r) @@ -75,7 +75,7 @@ import LoweringCtx.loweringCtx object Lowering: def compError: Block = - Throw(Value.StrLit("This code cannot be run as its compilation yielded an error.")) + Throw(Value.Lit(Tree.StrLit("This code cannot be run as its compilation yielded an error."))) def fail(err: ErrorReport)(using Raise): Block = raise(err) @@ -645,7 +645,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case st.UnitVal() => k(unit) case st.Lit(lit) => if lit =/= Tree.UnitLit(false) then warnStmt - k(Value.Lit(lit, lit.erasedType)) + k(Value.Lit(lit)) case st.Ret(res) => returnedTerm(res) case st.Throw(res) => @@ -688,7 +688,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case st.Break(label, result, value) => value match case S(v) => term(v)(r => Assign(result, r, Break(label))) - case N => Assign(result, Value.UnitLit(false), Break(label)) + case N => Assign(result, Value.Lit(Tree.UnitLit(false)), Break(label)) case st.Continue(label) => Continue(label) case st.Asc(lhs, rhs) => @@ -733,14 +733,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if k.isInstanceOf[TailOp] then Match( ar1, (Case.Lit(posLit) -> term_nonTail(arg2)(k)) :: Nil, - S(k(Value.Lit(negLit, negLit.erasedType))), + S(k(Value.Lit(negLit))), Unreachable("tail operation in branches"), ) else val ts = loweringCtx.registerTempSymbol(N) Match( ar1, (Case.Lit(posLit) -> term_nonTail(arg2)(Assign(ts, _, End()))) :: Nil, - S(Assign(ts, Value.Lit(negLit, negLit.erasedType), End())), + S(Assign(ts, Value.Lit(negLit), End())), k(ts.asSimpleRef), ) sym match @@ -1019,10 +1019,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): def setupSymbol(symbol: Local)(k: Result => Block)(using LoweringCtx): Block = k(Instantiate(mut = false, State.termSymbol.asSimpleRef.selSN("Symbol"), - (Value.StrLit(symbol.nme).asArg :: Nil) :: Nil)) + (Value.Lit(Tree.StrLit(symbol.nme)).asArg :: Nil) :: Nil)) def quotePattern(p: FlatPattern)(k: Result => Block)(using LoweringCtx): Block = p match - case FlatPattern.Lit(lit) => setupTerm("LitPattern", Value.Lit(lit, lit.erasedType) :: Nil)(k) + case FlatPattern.Lit(lit) => setupTerm("LitPattern", Value.Lit(lit) :: Nil)(k) case _ => // TODO fail: ErrorReport( @@ -1068,10 +1068,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): def quote(t: st)(k: Result => Block)(using LoweringCtx): Block = t match case Lit(lit) => - setupTerm("Lit", Value.Lit(lit, lit.erasedType) :: Nil)(k) + setupTerm("Lit", Value.Lit(lit) :: Nil)(k) case Ref(sym) if Elaborator.binaryOps.contains(sym.nme) => // builtin symbols val l = loweringCtx.registerTempSymbol(N) - setupTerm("Builtin", Value.StrLit(sym.nme) :: Nil)(k) + setupTerm("Builtin", Value.Lit(Tree.StrLit(sym.nme)) :: Nil)(k) case Resolved(Ref(sym), disamb) => sym match case sym: BlockMemberSymbol => k(sym.asMemberRef(disamb)) @@ -1081,8 +1081,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case SynthSel(Ref(sym: ModuleOrObjectSymbol), name) => // Local cross-stage references setupSymbol(sym): r1 => val l1, l2 = loweringCtx.registerTempSymbol(N) - Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.UnitLit(false) :: Nil)(r2 => - Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.StrLit(name.name) :: Nil)(k)) + Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.UnitLit(false)) :: Nil)(r2 => + Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) )) case SynthSel(Ref(sym: BlockMemberSymbol), name) => // Multi-file cross-stage references if config.qqEnabled then fail: @@ -1097,8 +1097,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val basePath = base.up val targetPath = filename val relPath = targetPath.relativeTo(basePath).map(_.toString).getOrElse(targetPath.toString) - Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.StrLit(relPath) :: Nil)(r2 => - Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.StrLit(name.name) :: Nil)(k)) + Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.StrLit(relPath)) :: Nil)(r2 => + Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) )) case _ => fail: ErrorReport( @@ -1419,7 +1419,7 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State) .ifthen(selRes.asSimpleRef, Case.Lit(syntax.Tree.UnitLit(false)), Throw(Instantiate(mut = false, Select(State.globalThisSymbol.asThis, Tree.Ident("Error"))(N), - (Value.StrLit(s"Access to required field '${nme.name}' yielded 'undefined'").asArg :: Nil) :: Nil)) + (Value.Lit(syntax.Tree.StrLit(s"Access to required field '${nme.name}' yielded 'undefined'")).asArg :: Nil) :: Nil)) ) .rest(k(selRes.asSimpleRef)) @@ -1475,10 +1475,10 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) val resInspectedSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogResInspected") - val psSymArgs = psInspectedSyms.zipWithIndex.foldRight[Ls[Arg]](Arg(N, Value.StrLit(")")) :: Nil): + val psSymArgs = psInspectedSyms.zipWithIndex.foldRight[Ls[Arg]](Arg(N, Value.Lit(Tree.StrLit(")"))) :: Nil): case (((s, p), i), acc) => if i == psInspectedSyms.length - 1 then Arg(N, s.asSimpleRef) :: acc - else Arg(N, s.asSimpleRef) :: Arg(N, Value.StrLit(", ")) :: acc + else Arg(N, s.asSimpleRef) :: Arg(N, Value.Lit(Tree.StrLit(", "))) :: acc val tmp1, tmp2, tmp3 = loweringCtx.registerTempSymbol(N) @@ -1488,7 +1488,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) assignStmts( enterMsgSym -> pureCall( strConcatFn, - Arg(N, Value.StrLit(s"CALL ${name.getOrElse("[arrow function]")}(")) :: psSymArgs + Arg(N, Value.Lit(Tree.StrLit(s"CALL ${name.getOrElse("[arrow function]")}("))) :: psSymArgs ), tmp1 -> pureCall(traceLogFn, Arg(N, enterMsgSym.asSimpleRef) :: Nil), prevIndentLvlSym -> pureCall(traceLogIndentFn, Nil) @@ -1499,7 +1499,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) resInspectedSym -> pureCall(inspectFn, Arg(N, resSym.asSimpleRef) :: Nil), retMsgSym -> pureCall( strConcatFn, - Arg(N, Value.StrLit("=> ")) :: Arg(N, resInspectedSym.asSimpleRef) :: Nil + Arg(N, Value.Lit(Tree.StrLit("=> "))) :: Arg(N, resInspectedSym.asSimpleRef) :: Nil ), tmp2 -> pureCall(traceLogResetFn, Arg(N, prevIndentLvlSym.asSimpleRef) :: Nil), tmp3 -> pureCall(traceLogFn, Arg(N, retMsgSym.asSimpleRef) :: Nil) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index b5dfc905f8..7b80670ce8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -168,7 +168,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case Value.MemberRef(bms, disamb) => showSymbol(bms.nme, S(disamb)) case Value.This(sym) if sym === State.globalThisSymbol => showSymbol(sym.nme, S(sym.asDefnSym)) case Value.This(sym) => doc"${print(sym)}.this" - case Value.Lit(lit, _) => doc"${lit.idStr}" + case Value.Lit(lit) => doc"${lit.idStr}" def print(path: Path)(using Scope): Document = path match case sel @ Select(qual, name) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala index dec706e902..62c13f56ea 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala @@ -38,7 +38,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n case b: Bool => Tree.BoolLit(b) case s: Str => Tree.StrLit(s) case n: BigDecimal => Tree.DecLit(n) - Value.Lit(l, l.erasedType) + Value.Lit(l) extension [A, B](ls: Ls[(A => B) => B]) def collectApply(f: Ls[A] => B): B = @@ -213,7 +213,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n raise(ErrorReport(msg"Instantiate with multiple argument lists not supported in staged module." -> r.toLoc :: Nil)) End() // desugar Runtime.Tuple.get into Select - case Call(fun, Ls(Arg(_, scrut), Arg(_, Value.Lit(Tree.IntLit(idx), _))) :: _) if fun == Value.SimpleRef(State.runtimeSymbol).selSN("Tuple").selSN("get") => + case Call(fun, Ls(Arg(_, scrut), Arg(_, Value.Lit(Tree.IntLit(idx)))) :: _) if fun == Value.SimpleRef(State.runtimeSymbol).selSN("Tuple").selSN("get") => transformPath(Select(scrut, Tree.Ident(idx.toString()))(N))(k) case Call(fun, argss) => val stagedFunPath = fun match @@ -261,7 +261,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n transformOption(pOpt, transformParamList)(k) def transformCase(cse: Case)(using Context)(k: Path => Block): Block = cse match - case Case.Lit(lit) => blockCtor("Lit", Ls(Value.Lit(lit, lit.erasedType)))(k) + case Case.Lit(lit) => blockCtor("Lit", Ls(Value.Lit(lit)))(k) case Case.Cls(cls, path) => transformSymbol(cls): cls => transformPath(path): path => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala index 5077bab0df..55b88d3c89 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala @@ -118,7 +118,7 @@ private object PostCondAnalysisImpl extends CachedAnalysis[Block, PostCondRes]: private def res(lhs: Opt[Local], rhs: Result, rest: Block) = if rhs.isPure then lhs match case Some(lhs) => rhs match - case Value.Lit(lit, _) => PostCondRes(false, false, Map(lhs -> lit)) >=> analyze(rest) + case Value.Lit(lit) => PostCondRes(false, false, Map(lhs -> lit)) >=> analyze(rest) case _ => analyze(rest) case None => analyze(rest) else analyze(rest).markImpure diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala index 03f7b33aef..bb6ab15766 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -17,7 +17,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S private val runStackSafePath: Path = runtimePath.selN(Tree.Ident("runStackSafe")) private val stackDepthPath: Path = runtimePath.selN(STACK_DEPTH_IDENT) - private def intLit(n: BigInt) = Value.IntLit(n) + private def intLit(n: BigInt) = Value.Lit(Tree.IntLit(n)) private def op(op: String, a: Path, b: Path) = Call(State.builtinOpsMap(op).asSimpleRef, (a.asArg :: b.asArg :: Nil) ne_:: Nil)(true, false, false) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 0c28b53563..3bcf8a89ba 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -350,7 +350,7 @@ class TailRecOpt(using State, TL, Raise): // The code used to continute the loop. val cont = if scc.funs.size === 1 then Continue(loopSym) - else Assign(curIdSym, Value.IntLit(dSymIds(calleeSym)), Continue(loopSym)) + else Assign(curIdSym, Value.Lit(Tree.IntLit(dSymIds(calleeSym))), Continue(loopSym)) // In some cases, we could have assignments like this: // param0 = whatever // param1 = @@ -433,7 +433,7 @@ class TailRecOpt(using State, TL, Raise): // Main args def mainArgs(rest: List[Path]) = (0 until paramList.size).toList.foldRight(rest): - case (n, acc) => DynSelect(tupleSym.asSimpleRef, Value.IntLit(n), true) :: acc + case (n, acc) => DynSelect(tupleSym.asSimpleRef, Value.Lit(Tree.IntLit(n)), true) :: acc // If the rest param exists, append a slice val (initialBlk: (Block => Block), pathList: List[Path]) = @@ -445,8 +445,8 @@ class TailRecOpt(using State, TL, Raise): .sel(Tree.Ident("Tuple"), State.tupleSymbol) .sel(Tree.Ident("slice"), State.tupleSliceSymbol), (tupleSym.asSimpleRef.asArg - :: Value.IntLit(paramList.length).asArg - :: Value.IntLit(0).asArg + :: Value.Lit(Tree.IntLit(paramList.length)).asArg + :: Value.Lit(Tree.IntLit(0)).asArg :: Nil) ne_:: Nil )(true, false, false) val blk = blockBuilder @@ -495,9 +495,9 @@ class TailRecOpt(using State, TL, Raise): else funs.map: f => val paramArgs = getParamSyms(f).map(s => s.asSimpleRef.asArg) val args = - Value.IntLit(dSymIds(f.dSym)).asArg + Value.Lit(Tree.IntLit(dSymIds(f.dSym))).asArg :: paramArgs - ::: List.fill(maxParamLen - paramArgs.length)(Value.UnitLit(false).asArg) + ::: List.fill(maxParamLen - paramArgs.length)(Value.Lit(Tree.UnitLit(false)).asArg) val newBod = Return( Call(sel, args ne_:: Nil)(true, false, false), ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala index 5eb54d3817..95a8d49cd6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala @@ -553,7 +553,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): transformedOgBody, mkReturnCall(parentFunSym, parentFunFvs)) case None => - Begin(transformedOgBody, Return(Value.UnitLit(true))) + Begin(transformedOgBody, Return(Value.Lit(Tree.UnitLit(true)))) val refreshedFvSymbols = restFnFvs(restFunId).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"))) val bodyWithCorrectSymbols = new RefreshSymbol(refreshedFvSymbols.toMap).applyBlock(actualBody) FunDefn(tsym.owner, bms, tsym, refreshedFvSymbols.unzip._2.asParamList :: Nil, bodyWithCorrectSymbols)(N, annotations = PrivateModifier :: Nil) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala index 6d0937b141..f1316f6d72 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala @@ -117,7 +117,7 @@ object PossibleTrackableTupleSelect: s match case Call( Select(Select(Value.SimpleRef(runtimeSym), Tree.Ident("Tuple")), Tree.Ident("get")), - (Arg(N, ref@Value.SimpleRef(scrut)) :: Arg(N, Value.Lit(Tree.IntLit(n), _)) :: Nil) :: Nil + (Arg(N, ref@Value.SimpleRef(scrut)) :: Arg(N, Value.Lit(Tree.IntLit(n))) :: Nil) :: Nil ) if runtimeSym is eState.runtimeSymbol => S(ref -> n.toInt) case _ => N @@ -598,7 +598,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using case v@Value.SimpleRef(l) => applyValueSimpleRef(v, recordAffinity = true) case v@Value.MemberRef(_, _) => applyValueMemberRef(v, recordAffinity = true) case Value.This(sym) => () - case Value.Lit(lit, _) => () + case Value.Lit(lit) => () override def applyFunDefn(fun: FunDefn): Unit = ctxTracker.inFun(fun): @@ -1030,7 +1030,7 @@ class FlowConstraintsCollector( cc.constrain(processResult(qual), UnknownCons) cc.constrain(processResult(fld), UnknownCons) UnknownProd - case Value.Lit(lit, _) => UnknownProd + case Value.Lit(lit) => UnknownProd } end FlowConstraintsCollector diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index 4a9764c3f8..816d9c3f2a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -57,7 +57,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case Operand(prec: Int) def mkErr(errMsg: Message)(using Raise, Scope): Document = - doc"throw globalThis.Error(${result(Value.StrLit(errMsg.show))})" + doc"throw globalThis.Error(${result(Value.Lit(syntax.Tree.StrLit(errMsg.show)))})" def errExpr(errMsg: Message)(using Raise, Scope): Document = raise(ErrorReport(errMsg -> N :: Nil, @@ -174,8 +174,8 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: // * Module self-references use the module name itself instead of `this` scope.lookup_!(ts, r.toLoc) case Value.This(sym) => scope.findThis_!(sym) - case Value.Lit(Tree.StrLit(value), _) => makeStringLiteral(value) - case Value.Lit(lit, _) => lit.idStr + case Value.Lit(Tree.StrLit(value)) => makeStringLiteral(value) + case Value.Lit(lit) => lit.idStr case Value.MemberRef(bms, disamb) => if disamb.shouldBeLifted then doc"${scope.lookup_!(bms, bms.toLoc)}.class" else scope.lookup_!(bms, r.toLoc) @@ -257,9 +257,9 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case Record(mut, flds) => val inner = bracketed(pre = "{", post = "}", insertBreak = true): flds.map: - case RcdArg(S(Value.Lit(IntLit(idx), _)), v) => + case RcdArg(S(Value.Lit(IntLit(idx))), v) => doc"${idx.toString}: ${result(v)}" - case RcdArg(S(Value.Lit(StrLit(idx), _)), v) => + case RcdArg(S(Value.Lit(StrLit(idx))), v) => doc"${if isValidIdentifier(idx) then idx else s"\"$idx\""}: ${result(v)}" case RcdArg(S(idx), v) => doc"[${result(idx)}]: ${result(v)}" @@ -326,7 +326,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: => lastBlkAssign(b) match // the one branch ends by assigning `nextInt` to `scrutSym` - case S(Assign(`scrutSym_`, Value.Lit(Tree.IntLit(nextInt), _), _)) => + case S(Assign(`scrutSym_`, Value.Lit(Tree.IntLit(nextInt)), _)) => unapplyImpl(rest, (curVal_, b) :: acc, S(scrut_), S(nextInt)) case _ => S((scrut_, (curVal_, b) :: acc, rest)) @@ -645,7 +645,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: doc" # $resJS" - case Return(Value.Lit(UnitLit(false), _)) => doc" # return${mkSemi}" + case Return(Value.Lit(UnitLit(false))) => doc" # return${mkSemi}" case Return(res) => doc" # return ${result(res)}${mkSemi}" case Match(scrut, Nil, els, rest) => @@ -660,7 +660,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case SpecializedSwitch(scrut, cases, dflt, rest) => val switchBod = cases.foldLeft(doc""): (acc, arm) => val needsBreak = arm.isInstanceOf[SwitchCase.ExplicitBreak] - acc :: doc" # case ${result(Value.Lit(arm.litValue, arm.litValue.erasedType))}: #{ ${ + acc :: doc" # case ${result(Value.Lit(arm.litValue))}: #{ ${ // * Note: we use `block` here so that Scoped nodes will create proper brace sections, // * necessary since `case` clauses do not create a new scope, // * so something like `switch (x) { case 1: let y = 1; break; case 2: let y = 2 }` is ill-formed! diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 5f722d5d04..70bb4401d2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -312,7 +312,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): ctx.fn_ctx.get(sym) match case None => k(ctx.findName(sym) |> sr) case Some(_) => bErrStop(msg"Unsupported value: This with function context") - case Value.Lit(lit, _) => k(Expr.Literal(lit)) + case Value.Lit(lit) => k(Expr.Literal(lit)) private def getClassOfField(p: DefinitionSymbol[?])(using ctx: Ctx)(using Raise, Scope): Local = @@ -344,7 +344,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): trace[Node](s"bPath { $p } begin", x => s"bPath end: ${x.show}"): p match case s @ Select(Value.MemberRef(sym, _), Tree.Ident("Unit")) if sym is ctx.builtinSym.runtimeSym.get => - bPath(Value.UnitLit(false))(k) + bPath(Value.Lit(Tree.UnitLit(false)))(k) case s @ DynSelect(qual, fld, arrayIdx) => bErrStop(msg"Unsupported dynamic selection") case s @ Select(qual, name) => @@ -504,15 +504,15 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): Node.Case(e, casesList, defaultCase) case Return(res) => bResult(res)(x => Node.Result(Ls(x))) case Throw(Instantiate(false, Select(Value.SimpleRef(_), ident), - Ls(Arg(N, Value.Lit(Tree.StrLit(e), _))) :: Nil)) + Ls(Arg(N, Value.Lit(Tree.StrLit(e)))) :: Nil)) if ident.name === "Error" => Node.Panic(e) case Throw(Instantiate(false, Select(Value.MemberRef(_, _), ident), - Ls(Arg(N, Value.Lit(Tree.StrLit(e), _))) :: Nil)) + Ls(Arg(N, Value.Lit(Tree.StrLit(e)))) :: Nil)) if ident.name === "Error" => Node.Panic(e) case Throw(Instantiate(false, Select(Value.This(_), ident), - Ls(Arg(N, Value.Lit(Tree.StrLit(e), _))) :: Nil)) + Ls(Arg(N, Value.Lit(Tree.StrLit(e)))) :: Nil)) if ident.name === "Error" => Node.Panic(e) case Label(label, loop, body, rest) => TODO("Label not supported") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 93023a03fb..69e93b7f3f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -1030,7 +1030,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: errExtra: => Str, )(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr => Expr = fld match - case Value.Lit(IntLit(value), _) if value.isValidInt => + case Value.Lit(IntLit(value)) if value.isValidInt => val idx = value.toInt tupleRef => if idx >= 0 then i32.const(idx) @@ -1191,11 +1191,11 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ) def result(r: codegen.Result)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = r match - case Value.Lit(BoolLit(value), _) => + case Value.Lit(BoolLit(value)) => ref.i31(i32.const(if value then 1 else 0)) - case Value.Lit(IntLit(value), _) => + case Value.Lit(IntLit(value)) => withValidIntLit(value, r.toLoc)(intVal => ref.i31(i32.const(intVal))) - case Value.Lit(StrLit(value), _) => + case Value.Lit(StrLit(value)) => val lit = internStringLiteral(value) val stringCtor = getOrLoadStrCtorFunction call( @@ -1428,10 +1428,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case Select(Value.This(sym), id) if (sym eq State.globalThisSymbol) && id.name == "Error" => return as.headOption match case S(arg) => arg.value match - case Value.Lit(BoolLit(value), _) => ref.i31(i32.const(if value then 1 else 0)) - case Value.Lit(IntLit(value), _) => + case Value.Lit(BoolLit(value)) => ref.i31(i32.const(if value then 1 else 0)) + case Value.Lit(IntLit(value)) => withValidIntLit(value, arg.value.toLoc)(intVal => ref.i31(i32.const(intVal))) - case Value.Lit(StrLit(_), _) => result(arg.value) + case Value.Lit(StrLit(_)) => result(arg.value) case unsupported => warnExpr( msg"WatBuilder::result for Instantiate(...) of `globalThis.Error(...)` with payload `${ diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 87ab1a5d11..98424e1d2d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -416,7 +416,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C */ private def throwMatchErrorBlock = Throw(Instantiate(mut = false, Select(State.globalThisSymbol.asThis, Tree.Ident("Error"))(S(ctx.builtins.Error)), - (Value.StrLit("match error").asArg :: Nil) :: Nil)) // TODO add failed-match scrutinee info + (Value.Lit(syntax.Tree.StrLit("match error")).asArg :: Nil) :: Nil)) // TODO add failed-match scrutinee info import syntax.Keyword.{`if`, `while`} diff --git a/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls b/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls index 031cf9ef58..97a3436100 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls @@ -24,9 +24,7 @@ //│ ╙── ^ //│ —————————————| Lowered IR Tree |———————————————————————————————————————————————————————————————————— //│ Program: -//│ main = Return of Lit: -//│ lit = IntLit of 2 -//│ erasedType = Primitive of Int +//│ main = Return of Lit of IntLit of 2 //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— //│ = 2 @@ -55,12 +53,8 @@ print("Hi") //│ argss = Ls of //│ Ls of //│ Arg: -//│ value = Lit: -//│ lit = StrLit of "Hi" -//│ erasedType = Primitive of Str -//│ rest = Return of Lit: -//│ lit = IntLit of 2 -//│ erasedType = Primitive of Int +//│ value = Lit of StrLit of "Hi" +//│ rest = Return of Lit of IntLit of 2 //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— //│ > Hi //│ = 2 diff --git a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls index cc14429e8e..feaaf5a5e4 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls @@ -27,9 +27,7 @@ x + 1 //│ main = Scoped(syms = {x⁰}): //│ body = Assign: //│ lhs = x⁰ -//│ rhs = Lit: -//│ lit = IntLit of 1 -//│ erasedType = Primitive of Int +//│ rhs = Lit of IntLit of 1 //│ rest = Return of Call: //│ fun = SimpleRef of builtin:+⁰ //│ argss = Ls of @@ -37,9 +35,7 @@ x + 1 //│ Arg: //│ value = SimpleRef of x⁰ //│ Arg: -//│ value = Lit: -//│ lit = IntLit of 1 -//│ erasedType = Primitive of Int +//│ value = Lit of IntLit of 1 //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let x⁰; set x⁰ = 1; return +⁰(x⁰, 1) //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— From df6d0654f6b4797d87199000112e8349cb891b05 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 28 May 2026 15:19:16 +0800 Subject: [PATCH 04/48] codegen: Add `MemberRef` to erasedType --- hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 4ff2d41b82..ccd12bafaf 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -1058,6 +1058,10 @@ enum Value extends Path with ProductWithExtraInfo: /** Returns the [[`ErasedType`]] of this value. */ def erasedType(using Ctx): ErasedType = this match + case MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol)) => ErasedType.AnyRef(false, disamb) + case MemberRef(_, disamb: TypeAliasSymbol) => + // TODO(Derppening): Do we preserve the `TypeAliasSymbol` here? + disamb.irClsLikeDefn.flatMap(_.sym.asClsOrMod).fold(ErasedType.objectRef)(ErasedType.AnyRef(false, _)) case This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => ErasedType.AnyRef(false, clsOrMod) case Lit(lit) => lit.erasedType case _ => ErasedType.objectRef From fbe4104160967fe9909d661ec7cd6f62d5ae9fd7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 29 May 2026 14:11:19 +0800 Subject: [PATCH 05/48] codegen: Add `erasedType` to `SimpleRef` --- .../src/main/scala/hkmc2/codegen/Block.scala | 50 +++--- .../scala/hkmc2/codegen/BlockSimplifier.scala | 30 ++-- .../hkmc2/codegen/BlockTransformer.scala | 4 +- .../scala/hkmc2/codegen/BlockTraverser.scala | 2 +- .../hkmc2/codegen/BufferableTransform.scala | 13 +- .../scala/hkmc2/codegen/DeadParamElim.scala | 2 +- .../scala/hkmc2/codegen/EtaExpansion.scala | 2 +- .../FirstClassFunctionTransformer.scala | 30 +++- .../scala/hkmc2/codegen/HandlerLowering.scala | 30 ++-- .../src/main/scala/hkmc2/codegen/Lifter.scala | 28 ++-- .../main/scala/hkmc2/codegen/Lowering.scala | 142 +++++++++--------- .../main/scala/hkmc2/codegen/Printer.scala | 2 +- .../codegen/ReflectionInstrumenter.scala | 10 +- .../hkmc2/codegen/SpecializedSwitch.scala | 2 +- .../hkmc2/codegen/StackSafeTransform.scala | 10 +- .../scala/hkmc2/codegen/SymbolRefresher.scala | 6 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 26 ++-- .../scala/hkmc2/codegen/UsedVarAnalyzer.scala | 2 +- .../scala/hkmc2/codegen/WorkerWrapper.scala | 2 +- .../hkmc2/codegen/deforest/Rewrite.scala | 10 +- .../codegen/flowAnalysis/FlowAnalysis.scala | 20 +-- .../scala/hkmc2/codegen/js/JSBuilder.scala | 14 +- .../scala/hkmc2/codegen/llir/Builder.scala | 16 +- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 12 +- .../hkmc2/semantics/ucs/Normalization.scala | 12 +- .../test/mlscript/block-staging/Functions.mls | 2 +- .../test/mlscript/codegen/BlockPrinter.mls | 6 +- 27 files changed, 257 insertions(+), 228 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index ccd12bafaf..7cdb507d16 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -545,7 +545,7 @@ object HandleBlock: N, Nil, S(par), handlerMtds, Nil, Nil, // Apparently, the lifter is not happy with any assignment in the preCtor... - Assign(State.noSymbol, Call(State.builtinOpsMap("super").asSimpleRef, args.map(_.asArg) ne_:: Nil)(true, true, false), End()), + Assign(State.noSymbol, Call(State.builtinOpsMap("super").asSimpleRef(N), args.map(_.asArg) ne_:: Nil)(true, true, false), End()), End(), N, N, @@ -556,7 +556,7 @@ object HandleBlock: .define(clsDefn) .assign(lhs, Instantiate(mut = true, clsDefn.sym.asMemberRef(cls), Nil :: Nil)) .define(bodyDefn) - .assign(res, handleSuspension(lhs.asSimpleRef, bodyDefn.sym.asMemberRef(bodyDefn.dSym))) + .assign(res, handleSuspension(lhs.asSimpleRef(ErasedType.AnyRef(false, cls)), bodyDefn.sym.asMemberRef(bodyDefn.dSym))) .rest(rest) def apply( @@ -878,7 +878,7 @@ enum ErasedType: /** Trait representing a Block IR element that has an [[`ErasedType`]]. */ trait HasErasedType: /** The [[`ErasedType`]] of this element, or `N` if the erased type is not known. */ - def erasedType: Opt[ErasedType] + val erasedType: Opt[ErasedType] /** Similar to `erasedType`, but coerces to the top type if the specific erased type is not known. */ def erasedType_!(using Ctx): ErasedType = erasedType.getOrElse(ErasedType.objectRef) @@ -899,7 +899,7 @@ sealed abstract class Result extends AutoLocated: // def extraInfo: Str = toLoc.toString def showDbg(using DebugPrinter): Str = this match - case Value.SimpleRef(l) => l.showAsPlain + case Value.SimpleRef(l, _) => l.showAsPlain case Value.MemberRef(l, disamb) => s"${l.showAsPlain}${s"‹${disamb.showAsPlain}›"}" case Value.This(sym) => s"this[${sym.showAsPlain}]" case Value.Lit(lit) => lit.idStr @@ -919,7 +919,7 @@ sealed abstract class Result extends AutoLocated: q.isPure && sel.symbol.exists(_.isPure) case c @ Call(fun, ass) if c.isKnownUnsaturatedCall => fun.isPure && ass.forall(_.forall(a => a.spread.isEmpty && a.value.isPure)) - case Call(Value.SimpleRef(bs: BuiltinSymbol), ass) if bs.isPure => + case Call(Value.SimpleRef(bs: BuiltinSymbol, _), ass) if bs.isPure => ass.forall(_.forall(_.value.isPure)) case Record(mut, args) => args.forall(_.value.isPure) case Tuple(mut, elems) => elems.forall(_.value.isPure) @@ -938,7 +938,7 @@ sealed abstract class Result extends AutoLocated: case Lambda(params, body) => Vector.single(params) case Tuple(mut, elems) => elems.iterator.map(_.value).toVector case Record(mut, elems) => elems.iterator.map(_.value).toVector - case Value.SimpleRef(l) => Vector.empty + case Value.SimpleRef(l, _) => Vector.empty case Value.MemberRef(bms, disamb) => Vector.empty case Value.This(sym) => Vector.empty case Value.Lit(lit) => Vector.single(lit) @@ -960,7 +960,7 @@ sealed abstract class Result extends AutoLocated: case Tuple(mut, elems) => elems.flatMap(_.value.freeVars).toSet case Record(mut, args) => args.flatMap(arg => arg.idx.fold(Set.empty)(_.freeVars) ++ arg.value.freeVars).toSet - case Value.SimpleRef(l) => Set(l) + case Value.SimpleRef(l, _) => Set(l) case Value.MemberRef(bms, _) => Set(bms) case Value.This(sym) => Set.empty case Value.Lit(lit) => Set.empty @@ -974,12 +974,12 @@ sealed abstract class Result extends AutoLocated: case Tuple(mut, elems) => elems.flatMap(_.value.freeVarsLLIR).toSet case Record(mut, args) => args.flatMap(arg => arg.idx.fold(Set.empty)(_.freeVarsLLIR) ++ arg.value.freeVarsLLIR).toSet - case Value.SimpleRef(l: (BuiltinSymbol | TermSymbol)) => Set.empty - case Value.SimpleRef(l: DefinitionSymbol[?]) => l.defn match + case Value.SimpleRef(l: (BuiltinSymbol | TermSymbol), _) => Set.empty + case Value.SimpleRef(l: DefinitionSymbol[?], _) => l.defn match case Some(d: ClassLikeDef) => Set.empty case Some(d: TermDefinition) if d.companionClass.isDefined => Set.empty case _ => Set(l) - case Value.SimpleRef(l) => Set(l) + case Value.SimpleRef(l, _) => Set(l) case Value.MemberRef(l: (ClassSymbol | TermSymbol), disamb) => Set.empty case Value.MemberRef(l, disamb) => disamb.defn match case Some(d: ClassLikeDef) => Set.empty @@ -1047,8 +1047,8 @@ case class Select(qual: Path, name: Tree.Ident)(val symbol: Opt[DefinitionSymbol case class DynSelect(qual: Path, fld: Path, arrayIdx: Bool) extends Path -enum Value extends Path with ProductWithExtraInfo: - case SimpleRef(sym: LocalVarSymbol | BuiltinSymbol) +enum Value extends Path with HasErasedType with ProductWithExtraInfo: + case SimpleRef(sym: LocalVarSymbol | BuiltinSymbol, _erasedType: Opt[ErasedType]) /** * @param disamb The symbol disambiguating the definition that the reference refers to. */ @@ -1056,15 +1056,16 @@ enum Value extends Path with ProductWithExtraInfo: case This(sym: InnerSymbol) case Lit(lit: Literal) - /** Returns the [[`ErasedType`]] of this value. */ - def erasedType(using Ctx): ErasedType = this match - case MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol)) => ErasedType.AnyRef(false, disamb) + /** The [[`ErasedType`]] of this value. */ + val erasedType: Opt[ErasedType] = this match + case SimpleRef(_, erasedType) => erasedType + case MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol)) => S(ErasedType.AnyRef(false, disamb)) case MemberRef(_, disamb: TypeAliasSymbol) => // TODO(Derppening): Do we preserve the `TypeAliasSymbol` here? - disamb.irClsLikeDefn.flatMap(_.sym.asClsOrMod).fold(ErasedType.objectRef)(ErasedType.AnyRef(false, _)) - case This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => ErasedType.AnyRef(false, clsOrMod) - case Lit(lit) => lit.erasedType - case _ => ErasedType.objectRef + disamb.irClsLikeDefn.flatMap(_.sym.asClsOrMod).map(ErasedType.AnyRef(false, _)) + case This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => S(ErasedType.AnyRef(false, clsOrMod)) + case Lit(lit) => S(lit.erasedType) + case _ => N override def extraInfo(using DebugPrinter): Str = this match case MemberRef(bms, disamb) => s"disamb=${disamb.showAsPlain}" @@ -1079,7 +1080,7 @@ object Value: extension (r: RefLike) def symbol: Symbol = r match - case SimpleRef(l) => l + case SimpleRef(l, _) => l case MemberRef(bms, _) => bms case This(sym) => sym @@ -1087,7 +1088,7 @@ object Value: object Ref: def apply(l: Local, disamb: Opt[DefinitionSymbol[?]]): Value.RefLike = l match - case l: (LocalVarSymbol | BuiltinSymbol) => l.asSimpleRef + case l: (LocalVarSymbol | BuiltinSymbol) => l.asSimpleRef(N) case bms: BlockMemberSymbol => bms.asMemberRef: disamb.getOrElse: lastWords(s"Cannot disambiguate overloaded member symbol ${bms.nme}: no disambiguation provided") @@ -1102,7 +1103,7 @@ object Value: def apply(l: TempSymbol | VarSymbol | BuiltinSymbol): Value.RefLike = Ref(l, N) def unapply(v: Value): Opt[(Local, Opt[DefinitionSymbol[?]])] = v match - case SimpleRef(l) => S(l -> N) + case SimpleRef(l, _) => S(l -> N) case MemberRef(bms, disamb) => S(bms -> S(disamb)) case This(sym) => S(sym -> N) case _ => N @@ -1140,7 +1141,10 @@ extension (k: Block => Block) def blockBuilder: Block => Block = identity extension (s: (LocalVarSymbol | BuiltinSymbol)) - inline def asSimpleRef: Value.SimpleRef = Value.SimpleRef(s) + @deprecated("Use the overload accepting `Opt[ErasedType]` instead.") + inline def asSimpleRef: Value.SimpleRef = Value.SimpleRef(s, N) + inline def asSimpleRef(erasedType: ErasedType): Value.SimpleRef = Value.SimpleRef(s, S(erasedType)) + inline def asSimpleRef(erasedType: Opt[ErasedType]): Value.SimpleRef = Value.SimpleRef(s, erasedType) extension (bms: BlockMemberSymbol) inline def asMemberRef(disamb: DefinitionSymbol[?]): Value.MemberRef = Value.MemberRef(bms, disamb) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala index 3292609efd..a3cc2de9c7 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala @@ -129,7 +129,7 @@ class BlockSimplifier case ts: TermSymbol => usedPrivateFields += ts case _ => - case Value.SimpleRef(loc) => + case Value.SimpleRef(loc, _) => usedVars += loc case Value.MemberRef(loc, _) => usedVars += loc @@ -216,7 +216,7 @@ class BlockSimplifier override def applyValue(v: Value)(k: Value => Block) = v match // * Replace with `undefined` those references to local variables that are never assigned - case Value.SimpleRef(loc) if localVars.contains(loc) && !definedVars.contains(loc) => + case Value.SimpleRef(loc, _) if localVars.contains(loc) && !definedVars.contains(loc) => registerChange(s"${loc.showDbg} is never assigned; replacing read with undefined") // if !symbolsToPreserve(loc) then removedLocals += loc k(Value.Lit(syntax.Tree.UnitLit(false))) @@ -503,7 +503,7 @@ class BlockSimplifier // * Discard local variables that are assigned just to be returned // * Note: the reason we do this here and not in DeadCodeElim is that we need to check `capturedVars` - case Assign(lhs: LocalVar, rhs, Return(Value.SimpleRef(ret))) + case Assign(lhs: LocalVar, rhs, Return(Value.SimpleRef(ret, _))) if !inDryRun && (ret is lhs) && !capturedVars(lhs) && !symbolsToPreserve(lhs) => registerChange(s"tail-return ${lhs.showDbg} ~> ${rhs.showDbg}") @@ -513,7 +513,7 @@ class BlockSimplifier // log(s"Propagating ${lhs} := ${rhs} (${assignedResults.get(lhs)})") assignedResults += lhs -> Assigned(ass, rhs.match - case r @ Value.SimpleRef(sym: LocalVar) => + case r @ Value.SimpleRef(sym: LocalVar, _) => if capturedVars(sym) then N else val rhs2 = assignedResults(sym) @@ -640,9 +640,9 @@ class BlockSimplifier if gaveUp then Set.empty else p match - case Value.SimpleRef(r: LocalVar) if capturedVars(r) => + case Value.SimpleRef(r: LocalVar, _) if capturedVars(r) => giveUp - case Value.SimpleRef(r: LocalVar) => + case Value.SimpleRef(r: LocalVar, _) => assignedResults.get(r).fold(giveUp)(getShapesA) case Value.MemberRef(r, sym: ModuleOrObjectSymbol) => Set.single(sym) @@ -724,7 +724,7 @@ class BlockSimplifier override def applyValue(v: Value)(k: Value => Block): Block = v match - case Value.SimpleRef(loc: LocalVar) if !inDryRun && !capturedVars(loc) => + case Value.SimpleRef(loc: LocalVar, _) if !inDryRun && !capturedVars(loc) => val rs = assignedResults(loc) // log(s"Ref ${loc.showDbg} ${rs} ${localVars(loc)} ${capturedVars(loc)}") @@ -765,7 +765,7 @@ class BlockSimplifier case _ => litValue = false opt match - case S((r @ Value.SimpleRef(lv: LocalVar)) -> rhs) => + case S((r @ Value.SimpleRef(lv: LocalVar, _)) -> rhs) => if assignedResults(lv) is rhs then Set.single(r) ++ analyzeValues(rhs) else Set.empty @@ -813,7 +813,7 @@ class BlockSimplifier case call: Call if call.isKnownUnsaturatedCall && call.isPure => S(call) case _ => opt match - case S((Value.SimpleRef(next: LocalVar), nextAsst)) + case S((Value.SimpleRef(next: LocalVar, _), nextAsst)) if !capturedVars(next) && !seen(next) && (assignedResults(next) is nextAsst) => loop(nextAsst, seen + next) case _ => N @@ -829,7 +829,7 @@ class BlockSimplifier r match // * Try to propagate pure calls - case Value.SimpleRef(loc: LocalVar) if !inDryRun && !capturedVars(loc) => + case Value.SimpleRef(loc: LocalVar, _) if !inDryRun && !capturedVars(loc) => assignedPureCallPrefix(loc) match case S(call) => registerChange(s"${loc.showDbg} ~> ${call.showDbg}") @@ -838,7 +838,7 @@ class BlockSimplifier super.applyResult(r)(k) // * Try to combine pure calls (typically unsaturated calls) assigned to a variable into the current call - case c @ Call(Value.SimpleRef(loc: LocalVar), argss) if !inDryRun && !capturedVars(loc) => + case c @ Call(Value.SimpleRef(loc: LocalVar, _), argss) if !inDryRun && !capturedVars(loc) => assignedPureCallPrefix(loc) match case S(prefix) => registerChange(s"${loc.showDbg} call prefix ~> ${prefix.showDbg}") @@ -849,13 +849,13 @@ class BlockSimplifier case N => super.applyResult(r)(k) // * Remove uses of the strange builtin comma operator - case Call(Value.SimpleRef(sym: BuiltinSymbol), (arg1 :: arg2 :: Nil) :: Nil) + case Call(Value.SimpleRef(sym: BuiltinSymbol, _), (arg1 :: arg2 :: Nil) :: Nil) if sym.nme === "," && arg1.spread.isEmpty && arg2.spread.isEmpty => Assign.discard(arg1.value, k(arg2.value)) // * Partially evaluate calls to known builtins with literal arguments - case Call(Value.SimpleRef(sym: BuiltinSymbol), args :: Nil) if args.forall(_.value.isInstanceOf[Value]) => + case Call(Value.SimpleRef(sym: BuiltinSymbol, _), args :: Nil) if args.forall(_.value.isInstanceOf[Value]) => val argValues = args.map(_.value.asInstanceOf[Value]) args.foreach(a => assert(a.spread.isEmpty)) builtinEval.lift((sym.nme, argValues)) match @@ -1183,10 +1183,10 @@ class BlockSimplifier val copier = Copier(resSym, mapping) val newBlk = copier.applyBlock(blk) if extraArgss.isEmpty then - acc(Scoped(Set.single(resSym), newBlk(k(resSym.asSimpleRef)))) + acc(Scoped(Set.single(resSym), newBlk(k(resSym.asSimpleRef(N))))) else acc(Scoped(Set(resSym), newBlk( - k(Call(resSym.asSimpleRef, extraArgss.ne_!)(c.isMlsFun, c.mayRaiseEffects, false))))) + k(Call(resSym.asSimpleRef(N), extraArgss.ne_!)(c.isMlsFun, c.mayRaiseEffects, false))))) case (sym, value) :: argRest => val newSym = VarSymbol(sym.id) go(acc.assignScoped(newSym, value), argRest, mapping + (sym -> newSym)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala index c30ebbddfc..b30d299dbf 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala @@ -178,12 +178,12 @@ class BlockTransformer(subst: SymbolSubst): case v: Value => applyValue(v)(k) def applyValue(v: Value)(k: Value => Block) = v match - case Value.SimpleRef(l) => + case Value.SimpleRef(l, _) => val l2 = applyLocal(l) match case l: (LocalVarSymbol | BuiltinSymbol) => l case l2 => lastWords(s"Expected applyValue on `$l` (${l.getClass.getSimpleName}) to create a symbol of the same type, but got `$l2` (${l2.getClass.getSimpleName})") - k(if (l2 is l) then v else l2.asSimpleRef.withLocOf(v)) + k(if (l2 is l) then v else l2.asSimpleRef(N).withLocOf(v)) case Value.MemberRef(bms, disamb) => val bms2 = bms.subst val disamb2 = disamb.subst diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala index 24f44aa1ee..33b04fd3af 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala @@ -71,7 +71,7 @@ class BlockTraverser: case v: Value => applyValue(v) def applyValue(v: Value): Unit = v match - case Value.SimpleRef(l) => l.traverse + case Value.SimpleRef(l, _) => l.traverse case Value.MemberRef(bms, disamb) => bms.traverse disamb.traverse diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala index bd1f989f1b..9e1bf3b175 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala @@ -36,14 +36,17 @@ class BufferableTransform()(using Ctx, State, Raise): Param(p.flags, varMap(p.sym), p.sign, p.modulefulness) (params.map(pl => ParamList(pl.flags, pl.params.map(mapParam), pl.restParam.map(mapParam))), varMap.toMap) def mkFieldReplacer(buf: VarSymbol, baseIdx: VarSymbol, symMap: Map[Symbol, Symbol]) = + val baseIdxRef = baseIdx.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)) def getOffset(off: Int)(k: Path => Block): Block = val idxSymbol = new TempSymbol(N, "idx") - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), - k(DynSelect(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true)))) + val idxSymbolRef = idxSymbol.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)) + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef(N), (baseIdxRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), + k(DynSelect(buf.asSimpleRef(N).selSN("buf"), idxSymbolRef, true)))) def assignToOffset(off: Int, r: Result, rst: Block) = val idxSymbol = new TempSymbol(N, "idx") - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), - AssignDynField(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true, r, applyBlock(rst)))) + val idxSymbolRef = idxSymbol.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)) + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef(N), (baseIdxRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), + AssignDynField(buf.asSimpleRef(N).selSN("buf"), idxSymbolRef, true, r, applyBlock(rst)))) new BlockTransformer(SymbolSubst.Id): override def applyLocal(sym: Local): Local = symMap.getOrElse(sym, sym) override def applyBlock(b: Block): Block = b match @@ -79,7 +82,7 @@ class BufferableTransform()(using Ctx, State, Raise): val blk = mkFieldReplacer(buf, idx, symMap).applyBlock(f.body) FunDefn(f.owner, f.sym, TermSymbol(f.dSym.k, f.dSym.owner, f.dSym.id), PlainParamList( Param(FldFlags.empty, buf, N, Modulefulness.none) :: Param(FldFlags.empty, idx, N, Modulefulness.none) :: Nil) :: newParams, - if isCtor then Begin(blk, Return(idx.asSimpleRef)) else blk)(configOverride = f.configOverride, annotations = f.annotations) + if isCtor then Begin(blk, Return(idx.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)))) else blk)(configOverride = f.configOverride, annotations = f.annotations) val fakeCtor = transformFunDefn(FunDefn.withFreshSymbol( S(companionSym), BlockMemberSymbol("ctor", Nil, false), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala index 2b75982a87..61aa63d63c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala @@ -237,7 +237,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise): case _ => super.applyPath(p)(k) override def applyValue(v: Value)(k: Value => Block): Block = v match - case ref@Value.SimpleRef(l: VarSymbol) if activeEliminatedParams(l) => + case ref@Value.SimpleRef(l: VarSymbol, _) if activeEliminatedParams(l) => k(Value.Lit(Tree.UnitLit(false)).withLocOf(ref)) case _ => super.applyValue(v)(k) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala index 7bcca2730a..ca274b3ea1 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala @@ -175,7 +175,7 @@ class EtaExpansionRewrite(val etaExpansionSolver: EtaExpansionSolver)(using Rais Param.simple(new VarSymbol(new Tree.Ident(s"eta$$$idx$$$i"))) EtaParamList( ParamList(ParamListFlags.empty, params, N), - params.map(p => Arg(N, p.sym.asSimpleRef)), + params.map(p => Arg(N, p.sym.asSimpleRef(N))), ) else lastWords("not the same shape?") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala index dd22e929d8..7af1152a95 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala @@ -28,11 +28,11 @@ class FirstClassFunctionTransformer ) val defSym = new BlockMemberSymbol("Function$", Nil, false) val callDef = FunDefn.withFreshSymbol(Some(clsSym), new BlockMemberSymbol("call", Nil, true), params :: Nil, - Return(Call(p, params.params.map(_.sym.asSimpleRef.asArg) ne_:: Nil)(true, false, false)))(N, annotations = Nil) + Return(Call(p, params.params.map(_.sym.asSimpleRef(N).asArg) ne_:: Nil)(true, false, false)))(N, annotations = Nil) ClsLikeDefn(None, clsSym, defSym, None, syntax.Cls, None, Nil, Some(Select(State.globalThisSymbol.asThis, Tree.Ident("Function"))(Some(ctx.builtins.Function))), callDef :: Nil, Nil, Nil, Assign.discard( - Call(State.builtinOpsMap("super").asSimpleRef, Nil ne_:: Nil)(false, false, false), + Call(State.builtinOpsMap("super").asSimpleRef(N), Nil ne_:: Nil)(false, false, false), End()), End(), None, None)(N, annotations = Nil) private def getParamList(l: BlockMemberSymbol): Option[ParamList] = funDefns.get(l) match @@ -49,7 +49,17 @@ class FirstClassFunctionTransformer val clsDef = generateFCFunctionClass(ref, params) val tmp = new TempSymbol(None) val cls = clsDef.sym.asMemberRef(clsDef.isym) - Scoped(Set(clsDef.sym, tmp), Define(clsDef, Assign(tmp, Instantiate(false, cls, Nil :: Nil), k(tmp.asSimpleRef)))) + Scoped( + syms = Set(clsDef.sym, tmp), + body = Define( + clsDef, + Assign( + tmp, + Instantiate(false, cls, Nil :: Nil), + k(tmp.asSimpleRef(ErasedType.AnyRef(rsc = false, clsDef.isym.asClsOrMod.get))), + ) + ) + ) case _ => k(p) case sel: Select => sel.symbol match case Some(s: TermSymbol) if (s.k is syntax.Fun) => @@ -62,7 +72,17 @@ class FirstClassFunctionTransformer val clsDef = generateFCFunctionClass(sel, params) val tmp = new TempSymbol(None) val cls = clsDef.sym.asMemberRef(clsDef.isym) - Scoped(Set(clsDef.sym, tmp), Define(clsDef, Assign(tmp, Instantiate(false, cls, Nil :: Nil), k(tmp.asSimpleRef)))) + Scoped( + Set(clsDef.sym, tmp), + Define( + clsDef, + Assign( + tmp, + Instantiate(false, cls, Nil :: Nil), + k(tmp.asSimpleRef(ErasedType.AnyRef(rsc = false, clsDef.isym.asClsOrMod.get))), + ) + ) + ) case Some(_) => k(p) case _ => raise(ErrorReport(msg"Cannot determine if ${sel.name.name} is a function." -> sel.toLoc :: Nil, @@ -80,7 +100,7 @@ class FirstClassFunctionTransformer case c @ Call(fun, argss) => applyListOf(argss, (args, k2) => applyArgs(args)(k2)): argss2 => def call(f: Path) = Call(f, argss2.ne_!)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall) fun match - case ref @ Value.SimpleRef(sym) => sym match + case ref @ Value.SimpleRef(sym, N) => sym match case _: VarSymbol | _: TempSymbol => k(call(ref.selSN("call"))) case _ => k(call(fun)) case ref @ Value.MemberRef(_, _) => k(call(fun)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 11263f63b2..e9b4e50f97 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -92,7 +92,7 @@ object HandlerLowering: import HandlerLowering.* class HandlerPaths(using Elaborator.State): - val runtimePath: Path = State.runtimeSymbol.asSimpleRef + val runtimePath: Path = State.runtimeSymbol.asSimpleRef(N) val contClsPath: Path = runtimePath.selSN("FunctionContFrame").selSN("class") val mkEffectPath: Path = runtimePath.selSN("mkEffect") val handleBlockImplPath: Path = runtimePath.selSN("handleBlockImpl") @@ -138,18 +138,18 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, object StateTransition: private val transitionSymbol = freshTmp("transition") def apply(uid: StateId) = - Return(PureCall(transitionSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid))))) + Return(PureCall(transitionSymbol.asSimpleRef(N), List(Value.Lit(Tree.IntLit(uid))))) def unapply(blk: Block) = blk match - case Return(PureCall(Value.SimpleRef(`transitionSymbol`), List(Value.Lit(Tree.IntLit(uid))))) => + case Return(PureCall(Value.SimpleRef(`transitionSymbol`, N), List(Value.Lit(Tree.IntLit(uid))))) => S(uid) case _ => N object Unwind: private val unwindSymbol = freshTmp("unwind") def apply(uid: StateId, loc: Value) = - Return(PureCall(unwindSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid)), loc))) + Return(PureCall(unwindSymbol.asSimpleRef(N), List(Value.Lit(Tree.IntLit(uid)), loc))) def unapply(blk: Block) = blk match - case Return(PureCall(Value.SimpleRef(`unwindSymbol`), List(Value.Lit(Tree.IntLit(uid)), loc: Value))) => + case Return(PureCall(Value.SimpleRef(`unwindSymbol`, N), List(Value.Lit(Tree.IntLit(uid)), loc: Value))) => S(uid, loc) case _ => N @@ -540,9 +540,9 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val debugInfoSym = freshTmp(s"$debugNme$$debugInfo") // TODO: properly support spread argument by calculating the correct length. val rtArgLists = intLit(fun.params.length) :: fun.params.flatMap: pl => - intLit(pl.params.length) :: pl.params.map(p => p.sym.asSimpleRef) + intLit(pl.params.length) :: pl.params.map(p => p.sym.asSimpleRef(N)) val newCtx = HandlerCtx.FunctionLike(FunctionCtx(funcPath, thisPath, ResumeInfo(rtArgLists, varList, L(fun.sym)), - DebugInfo(debugNme, if opt.debug then debugInfoSym.asSimpleRef else unit), thisPath.isDefined && fun.params.isEmpty)) + DebugInfo(debugNme, if opt.debug then debugInfoSym.asSimpleRef(N) else unit), thisPath.isDefined && fun.params.isEmpty)) val bod2 = translateBlock(fun.body, newCtx, scopedVars) val fun2 = if fun.body is bod2 then fun else FunDefn(fun.owner, fun.sym, fun.dSym, fun.params, bod2)(fun.configOverride, fun.annotations) @@ -658,7 +658,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val headTransformed = segmentTailTransform.applyBlock(parts.states(head).blk) val initial: Block => Block = blk => Match( - pcVar.asSimpleRef, + pcVar.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), Case.Lit(Tree.IntLit(head)) -> headTransformed :: Nil, N, blk @@ -670,7 +670,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val transformed = transformState(uid) blk => Match( - pcVar.asSimpleRef, + pcVar.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), Case.Lit(Tree.IntLit(uid)) -> transformed :: Nil, N, acc(blk) @@ -691,17 +691,17 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, def getSaved(off: BigInt): (Block => Block, Path) = if off == 0 then return (id, DynSelect(paths.runtimePath.selSN("resumeArr"), paths.runtimePath.selSN("resumeIdx"), true)) - val addOne = Assign(getSavedTmp, Call(State.builtinOpsMap("+").asSimpleRef, (paths.runtimePath.selSN("resumeIdx").asArg :: intLit(off).asArg :: Nil) ne_:: Nil)(false, false, false), _) - (addOne, DynSelect(paths.runtimePath.selSN("resumeArr"), getSavedTmp.asSimpleRef, true)) + val addOne = Assign(getSavedTmp, Call(State.builtinOpsMap("+").asSimpleRef(N), (paths.runtimePath.selSN("resumeIdx").asArg :: intLit(off).asArg :: Nil) ne_:: Nil)(false, false, false), _) + (addOne, DynSelect(paths.runtimePath.selSN("resumeArr"), getSavedTmp.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), true)) - val resumeArrIndexed = DynSelect(paths.runtimePath.selSN("resumeArr"), getSavedTmp.asSimpleRef, true) - val plus = State.builtinOpsMap("+").asSimpleRef + val resumeArrIndexed = DynSelect(paths.runtimePath.selSN("resumeArr"), getSavedTmp.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), true) + val plus = State.builtinOpsMap("+").asSimpleRef(N) val preRestore = blockBuilder .assign(pcVar, paths.resumePc) .scopedVars(Set(getSavedTmp)) val restoreVars = vars.zipWithIndex.foldLeft(preRestore): case (builder, (local, idx)) => builder - .assign(getSavedTmp, if idx == 0 then paths.resumeIdx else Call(plus, (getSavedTmp.asSimpleRef.asArg :: intLit(1).asArg :: Nil) ne_:: Nil)(false, false, false)) + .assign(getSavedTmp, if idx == 0 then paths.resumeIdx else Call(plus, (getSavedTmp.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)).asArg :: intLit(1).asArg :: Nil) ne_:: Nil)(false, false, false)) .assign(local, resumeArrIndexed) Scoped( @@ -737,7 +737,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case EffectfulResult(r) => // Fallback case, this may lead to unnecessary assignments if it is assign-like val l = freshTmp() - Scoped(Set(l), effectCheck(l, r, k(l.asSimpleRef))) + Scoped(Set(l), effectCheck(l, r, k(l.asSimpleRef(N)))) case _ => super.applyResult(r)(k) topLevelTransform.applyBlock(b) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 56a802ebc3..ee06f97bda 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -387,7 +387,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): case Some(value) => syms.addOne(FunSyms(l, d) -> value) value - k(newSym.asSimpleRef) + k(newSym.asSimpleRef(N)) // Naked reference to a parameterized class constructor (used as a first-class function). // Replace with a partially applied curried C$ wrapper. @@ -404,7 +404,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): case Some(value) => syms.addOne(FunSyms(l, d) -> value) value - k(newSym.asSimpleRef) + k(newSym.asSimpleRef(N)) case _ => resolveDefnRef(l, d, ctor) match case Some(value) => k(value) @@ -557,7 +557,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val vd = ValDefn( tSym, fldSym, - varSym.asSimpleRef + varSym.asSimpleRef(N) )(N, Nil) (sym -> varSym, p, vd) @@ -833,7 +833,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): */ sealed trait GenericRewrittenScope[T] extends RewrittenScope[T]: lazy val captureSym = VarSymbol(Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath = captureSym.asSimpleRef + override lazy val capturePath = captureSym.asSimpleRef(N) protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid) protected val liftedObjsSyms: Map[InnerSymbol, VarSymbol] = liftedObjsOrdered.map: s => s -> VarSymbol(Tree.Ident(s.nme + "$")) @@ -848,7 +848,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): */ sealed trait ClsLikeRewrittenScope[T](sym: InnerSymbol) extends RewrittenScope[T]: lazy val captureSym = TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath = captureSym.asSimpleRef + override lazy val capturePath = captureSym.asSimpleRef(N) protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid) protected val liftedObjsSyms: Map[InnerSymbol, TermSymbol] = liftedObjsOrdered.map: s => s -> TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(s.nme + "$")) @@ -915,7 +915,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = captureSym.asSimpleRef + override lazy val capturePath: Path = captureSym.asSimpleRef(N) override def rewriteImpl: LifterResult[ClsLikeDefn] = val rewriterCtor = new BlockRewriter @@ -938,7 +938,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): with ClsLikeRewrittenScope[ClsLikeBody](obj.clsBody.isym): private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.clsBody.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = captureSym.asSimpleRef + override lazy val capturePath: Path = captureSym.asSimpleRef(N) override def rewriteImpl: LifterResult[ClsLikeBody] = val rewriterCtor = new BlockRewriter @@ -966,7 +966,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): .toMap override protected val passedSymsMap = passedSymsMap_.view.mapValues(_.asLocalPath).toMap - override protected val capSymsMap = capSymsMap_.view.mapValues(s => s.asSimpleRef).toMap + override protected val capSymsMap = capSymsMap_.view.mapValues(s => s.asSimpleRef(N)).toMap override protected val passedDefnsMap = defnSymsMap_.view.mapValues(_.asDefnRef).toMap val auxParams: List[Param] = @@ -1012,10 +1012,10 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): case Nil => lastWords("tried to make an aux defn for a function with no parameter list") val args = restSym match case Some(value) => - val tail = Arg(S(SpreadKind.Eager), value.asSimpleRef) :: Nil + val tail = Arg(S(SpreadKind.Eager), value.asSimpleRef(N)) :: Nil syms.foldLeft(tail): - case (acc, sym) => Arg(N, sym.asSimpleRef) :: acc - case None => syms.map(s => Arg(N, s.asSimpleRef)) + case (acc, sym) => Arg(N, sym.asSimpleRef(N)) :: acc + case None => syms.map(s => Arg(N, s.asSimpleRef(N))) val call = Call(fun.sym.asMemberRef(fun.dSym), args ne_:: Nil)(true, true, false) val bod = Return(call) @@ -1065,7 +1065,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = captureSym.asSimpleRef + override lazy val capturePath: Path = captureSym.asSimpleRef(N) private val passedSymsMap_ : Map[Local, (vs: VarSymbol, ts: TermSymbol)] = passedSymsOrdered.map: s => s -> @@ -1131,8 +1131,8 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): // Uses the symbols from pl1. def applyPlToPl(pl1: ParamList, pl2: ParamList): List[Arg] = (pl1.restParam, pl2.restParam) match - case (S(rp), S(_)) => pl1.params.foldRight(Arg(S(SpreadKind.Eager), rp.sym.asSimpleRef) :: Nil)((p, ls) => p.sym.asSimpleRef.asArg :: ls) - case (N, N) => pl1.paramSyms.map(s => s.asSimpleRef.asArg) + case (S(rp), S(_)) => pl1.params.foldRight(Arg(S(SpreadKind.Eager), rp.sym.asSimpleRef(N)) :: Nil)((p, ls) => p.sym.asSimpleRef(N).asArg :: ls) + case (N, N) => pl1.paramSyms.map(s => s.asSimpleRef(N).asArg) case _ => die // If class has a main param list, the aux list comes after it diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 4cd5db7f7b..496e44b5db 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -57,7 +57,7 @@ class LoweringCtx( Subst(map + kv) */ def apply(v: Value): Value = v match - case Value.SimpleRef(l) => map.getOrElse(l, v) + case Value.SimpleRef(l, _) => map.getOrElse(l, v) case _ => v object LoweringCtx: def loweringCtx(using sub: LoweringCtx): LoweringCtx = sub @@ -116,7 +116,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): private def wasmIntrinsicPath(sym: BuiltinSymbol, unary: Bool): Opt[Path] = if config.target is CompilationTarget.Wasm then val map = if unary then wasmUnaryIntrinsicMap else wasmBinaryIntrinsicMap - map.get(sym.nme).map(name => State.wasmSymbol.asSimpleRef.selN(Tree.Ident(name))) + map.get(sym.nme).map(name => State.wasmSymbol.asSimpleRef(N).selN(Tree.Ident(name))) else N private lazy val wasmIntrinsicSymbols: Set[BlockMemberSymbol] = Set( ctx.builtins.wasm.plus_impl, @@ -136,10 +136,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ) lazy val unreachableFn = - Select(State.runtimeSymbol.asSimpleRef, Tree.Ident("unreachable"))(S(State.unreachableSymbol)) + Select(State.runtimeSymbol.asSimpleRef(N), Tree.Ident("unreachable"))(S(State.unreachableSymbol)) def unit: Path = - Select(State.runtimeSymbol.asSimpleRef, Tree.Ident("Unit"))(S(State.unitSymbol)) + Select(State.runtimeSymbol.asSimpleRef(N), Tree.Ident("Unit"))(S(State.unitSymbol)) // type Rcd = (mut: Bool, args: List[RcdArg]) // * Better, but Scala's patmat exhaustiveness chokes on it @@ -155,7 +155,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): source = Diagnostic.Source.Compilation ) lowerSuperCtorCall( - State.builtinOpsMap("super").asSimpleRef, + State.builtinOpsMap("super").asSimpleRef(N), isMlsFun = true, isTailCall = false, args.headOption, @@ -415,7 +415,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case acc: NELs[Ls[Arg]] => val tmp = loweringCtx.registerTempSymbol(N, "baseCall") val call = Call(fr, acc)(isMlsFun, true, isTailCall).withLoc(loc) - Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall, loc)(k)) + Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef(N), args, remainingArgss, isTailCall, loc)(k)) case (_ :: _, Nil) => k(Call(fr, acc.reverse.ne_!)(isMlsFun, true, isTailCall).withLoc(loc)) fr.targetSymbol match @@ -435,7 +435,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case args :: remainingArgss => val tmp = loweringCtx.registerTempSymbol(N, "callPrefix") Assign(tmp, call, - lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall, loc)(k)) + lowerRemainingCalls(tmp.asSimpleRef(N), args, remainingArgss, isTailCall, loc)(k)) /** Lower an instantiation with multiple argument lists into `Instantiate` and `Call` nodes, * trying to group as many as possible into a single `Instantiate` @@ -459,7 +459,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case (Nil, args :: remainingArgss) => val tmp = loweringCtx.registerTempSymbol(N, "baseInst") Assign(tmp, buildInstantiate(acc.reverse), - lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall = false, N)(k)) + lowerRemainingCalls(tmp.asSimpleRef(N), args, remainingArgss, isTailCall = false, N)(k)) case (remainingParamss, Nil) => // * Eta-expand missing argument lists by creating lambdas for each remaining param list. // * This makes partial `new C(args...)` explicit instead of relying on the JS class curry. @@ -471,7 +471,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): softTODO(ps.restParam.isEmpty, "Eta expanding rest parameters in constructor definitions is not yet supported") val freshParams = (ps.params zip freshSyms).map((p, s) => Param(p.flags, s, N, p.modulefulness)) val freshParamList = ParamList(ps.flags, freshParams, N) - val freshArgs = freshSyms.map(s => Arg(N, s.asSimpleRef)) + val freshArgs = freshSyms.map(s => Arg(N, s.asSimpleRef(N))) Lambda(freshParamList, Return(etaExpand(rest, accArgss :+ freshArgs)))(Nil) k(etaExpand(remainingParamss, acc.reverse)) // * Resolve the class definition to get the constructor param lists. @@ -498,7 +498,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case remainingArgss => val tmp = loweringCtx.registerTempSymbol(N, "baseInst") Assign(tmp, buildInstantiate(as :: Nil), - lowerRemainingCalls(tmp.asSimpleRef, remainingArgss.head, remainingArgss.tail, isTailCall = false, N)(k)) + lowerRemainingCalls(tmp.asSimpleRef(N), remainingArgss.head, remainingArgss.tail, isTailCall = false, N)(k)) else zipArgs(ctorParamLists, args, Nil) def lowerArgs(arg: Term)(k: Ls[Arg] => Block)(using LoweringCtx): Block = @@ -616,7 +616,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): warnStmt (sym, disamb) match case (sym: (LocalVarSymbol | BuiltinSymbol), _) => - k(loweringCtx(sym.asSimpleRef.withLocOf(ref))) + k(loweringCtx(sym.asSimpleRef(N).withLocOf(ref))) case (sym: BlockMemberSymbol, _) => k(loweringCtx(sym.asMemberRef(disamb.orElse(sym.asPrincipal).get).withLocOf(ref))) case (sym: InnerSymbol, _) => @@ -668,13 +668,13 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign( isContinue, Call( - State.builtinOpsMap("===").asSimpleRef, - (bodyResult.asSimpleRef.asArg :: State.runtimeSymbol.asSimpleRef.selSN("Continue").asArg :: Nil) ne_:: Nil, + State.builtinOpsMap("===").asSimpleRef(N), + (bodyResult.asSimpleRef(N).asArg :: State.runtimeSymbol.asSimpleRef(N).selSN("Continue").asArg :: Nil) ne_:: Nil, )(true, false, false), Match( - isContinue.asSimpleRef, + isContinue.asSimpleRef(ErasedType.Primitive(PrimitiveType.Bool)), (Case.Lit(Tree.BoolLit(true)) -> Continue(label)) :: Nil, - S(Assign(result, bodyResult.asSimpleRef, Break(label))), + S(Assign(result, bodyResult.asSimpleRef(N), Break(label))), End("label continue-sentinel dispatch") ) ) @@ -683,7 +683,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): label, loop = hasLocalContinue(body), bodyBlock, - k(result.asSimpleRef) + k(result.asSimpleRef(N)) ) case st.Break(label, result, value) => value match @@ -711,7 +711,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ErrorReport( msg"Expected arguments for builtin operator '${sym.nme}'" -> t.toLoc :: Nil, S(arg), source = Diagnostic.Source.Compilation) - k(sym.asSimpleRef.withLocOf(ref)) + k(sym.asSimpleRef(N).withLocOf(ref)) case st.Tup(Fld(FldFlags.benign(), arg, N) :: Nil) => if !sym.unary then raise: ErrorReport( @@ -719,7 +719,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): source = Diagnostic.Source.Compilation) subTerm(arg): ar => val target = wasmIntrinsicPath(sym, unary = true) - .getOrElse(sym.asSimpleRef.withLocOf(ref)) + .getOrElse(sym.asSimpleRef(N).withLocOf(ref)) k(Call(target, (Arg(N, ar) :: Nil) ne_:: Nil)(true, false, false)) case st.Tup(Fld(FldFlags.benign(), arg1, N) :: Fld(FldFlags.benign(), arg2, N) :: Nil) => if !sym.binary then raise: @@ -741,7 +741,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ar1, (Case.Lit(posLit) -> term_nonTail(arg2)(Assign(ts, _, End()))) :: Nil, S(Assign(ts, Value.Lit(negLit), End())), - k(ts.asSimpleRef), + k(ts.asSimpleRef(N)), ) sym match case State.andSymbol => mkBooleanMatch(true) @@ -749,7 +749,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case _ => subTerm_nonTail(arg2): ar2 => val target = wasmIntrinsicPath(sym, unary = false) - .getOrElse(sym.asSimpleRef.withLocOf(ref)) + .getOrElse(sym.asSimpleRef(N).withLocOf(ref)) k(Call(target, (Arg(N, ar1) :: Arg(N, ar2) :: Nil) ne_:: Nil)(true, false, false)) case _ => fail: ErrorReport( @@ -789,21 +789,21 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): instantiated match case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.bitand) => - conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("bitand"))) + conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("bitand"))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.bitnot) => - conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("bitnot"))) + conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("bitnot"))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.bitor) => - conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("bitor"))) + conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("bitor"))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.shl) => - conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("shl"))) + conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("shl"))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.try_catch) => - conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("try_catch"))) + conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("try_catch"))) case t if t.resolvedSym.exists { case sym: BlockMemberSymbol => wasmIntrinsicSymbols.contains(sym) case _ => false } => val sym = t.resolvedSym.get.asInstanceOf[BlockMemberSymbol] - conclude(State.wasmSymbol.asSimpleRef.selN(Tree.Ident(sym.nme))) + conclude(State.wasmSymbol.asSimpleRef(N).selN(Tree.Ident(sym.nme))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.debug.printStack) => if !config.effectHandlers.exists(_.debug) then return fail: @@ -811,7 +811,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): msg"Debugging functions are not enabled" -> t.toLoc :: Nil, source = Diagnostic.Source.Compilation) - conclude(State.runtimeSymbol.asSimpleRef.selSN("raisePrintStackEffect").withLocOf(baseF)) + conclude(State.runtimeSymbol.asSimpleRef(N).selSN("raisePrintStackEffect").withLocOf(baseF)) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.scope.locally) => // scope.locally only applies to the innermost call; extra args are applied on top if allArgs.length > 1 then @@ -867,7 +867,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): subTerms(as): asr => HandleBlock(lhs, resSym, par, asr, cls, handlers, inScopedBlock(returnedTerm(bod)), - k(resSym.asSimpleRef)) + k(resSym.asSimpleRef(N))) case st.Blk(sts, res) => block(sts, R(res), inStmtPos = inStmtPos)(k) case Assgn(lhs, rhs) => lhs match @@ -961,7 +961,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): TryBlock( subTerm_nonTail(sub)(p => Assign(l, p, End())), subTerm_nonTail(finallyDo)(_ => End()), - k(l.asSimpleRef) + k(l.asSimpleRef(N)) ) case Quoted(body) => quote(body)(k) @@ -1012,13 +1012,13 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): // subTerm(t)(k) def setupTerm(name: Str, args: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = - k(Instantiate(mut = false, State.termSymbol.asSimpleRef.selSN(name), args.map(_.asArg) :: Nil)) + k(Instantiate(mut = false, State.termSymbol.asSimpleRef(N).selSN(name), args.map(_.asArg) :: Nil)) def setupQuotedKeyword(kw: Str): Path = - State.termSymbol.asSimpleRef.selSN("Keyword").selSN(kw) + State.termSymbol.asSimpleRef(N).selSN("Keyword").selSN(kw) def setupSymbol(symbol: Local)(k: Result => Block)(using LoweringCtx): Block = - k(Instantiate(mut = false, State.termSymbol.asSimpleRef.selSN("Symbol"), + k(Instantiate(mut = false, State.termSymbol.asSimpleRef(N).selSN("Symbol"), (Value.Lit(Tree.StrLit(symbol.nme)).asArg :: Nil) :: Nil)) def quotePattern(p: FlatPattern)(k: Result => Block)(using LoweringCtx): Block = p match @@ -1037,20 +1037,20 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): blockBuilder.assign(l1, r1) .chain(b => quotePattern(pattern)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(continuation)(r3 => Assign(l3, r3, b))) - .chain(b => setupTerm("Branch", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(r4 => Assign(l4, r4, b))) + .chain(b => setupTerm("Branch", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef(N)))(r4 => Assign(l4, r4, b))) .chain(b => quoteSplit(tail)(r5 => Assign(l5, r5, b))) - .rest(setupTerm("Cons", (l4 :: l5 :: Nil).map(s => s.asSimpleRef))(k)) + .rest(setupTerm("Cons", (l4 :: l5 :: Nil).map(s => s.asSimpleRef(N)))(k)) case Split.Let(sym, term, tail) => setupSymbol(sym): r1 => loweringCtx.collectScopedSym(sym) val l1, l2, l3 = loweringCtx.registerTempSymbol(N) blockBuilder.assign(l1, r1) - .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) + .chain(b => setupTerm("Ref", l1.asSimpleRef(N) :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(term)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(tail)(r3 => Assign(l3, r3, b))) - .rest(setupTerm("Let", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(k)) + .rest(setupTerm("Let", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef(N)))(k)) case Split.Else(default) => quote(default): r => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, setupTerm("Else", l.asSimpleRef :: Nil)(k)) + Assign(l, r, setupTerm("Else", l.asSimpleRef(N) :: Nil)(k)) case Split.End => setupTerm("End", Nil)(k) case Split.LetSplit(sym, tail) => setupSymbol(sym): r1 => loweringCtx.collectScopedSym(sym) @@ -1064,7 +1064,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): lazy val setupFilename: Path = val state = summon[State] - state.importSymbol.asSimpleRef.selSN("meta").selSN("url") + state.importSymbol.asSimpleRef(N).selSN("meta").selSN("url") def quote(t: st)(k: Result => Block)(using LoweringCtx): Block = t match case Lit(lit) => @@ -1075,14 +1075,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Resolved(Ref(sym), disamb) => sym match case sym: BlockMemberSymbol => k(sym.asMemberRef(disamb)) - case sym: (LocalVarSymbol | BuiltinSymbol) => k(sym.asSimpleRef) + case sym: (LocalVarSymbol | BuiltinSymbol) => k(sym.asSimpleRef(N)) case sym => lastWords(s"Unexpected symbol kind ${sym.getClass.getSimpleName}: $sym") case Ref(sym) => k(sym.asPath) case SynthSel(Ref(sym: ModuleOrObjectSymbol), name) => // Local cross-stage references setupSymbol(sym): r1 => val l1, l2 = loweringCtx.registerTempSymbol(N) - Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.UnitLit(false)) :: Nil)(r2 => - Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) + Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef(N) :: setupFilename :: Value.Lit(syntax.Tree.UnitLit(false)) :: Nil)(r2 => + Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef(N) :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) )) case SynthSel(Ref(sym: BlockMemberSymbol), name) => // Multi-file cross-stage references if config.qqEnabled then fail: @@ -1097,8 +1097,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val basePath = base.up val targetPath = filename val relPath = targetPath.relativeTo(basePath).map(_.toString).getOrElse(targetPath.toString) - Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.StrLit(relPath)) :: Nil)(r2 => - Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) + Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef(N) :: setupFilename :: Value.Lit(syntax.Tree.StrLit(relPath)) :: Nil)(r2 => + Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef(N) :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) )) case _ => fail: ErrorReport( @@ -1114,13 +1114,13 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign( arr, Tuple(mut = false, ds.reverse.map(_.asArg)), - Assign(l, r, setupTerm("Lam", arr.asSimpleRef :: l.asSimpleRef :: Nil)(k))) + Assign(l, r, setupTerm("Lam", arr.asSimpleRef(N) :: l.asSimpleRef(N) :: Nil)(k))) case sym :: rest => loweringCtx.collectScopedSym(sym) setupSymbol(sym): r => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, setupTerm("Ref", l.asSimpleRef :: Nil): r1 => - Assign(sym, r1, rec(rest, l.asSimpleRef :: ds)(k))) + Assign(l, r, setupTerm("Ref", l.asSimpleRef(N) :: Nil): r1 => + Assign(sym, r1, rec(rest, l.asSimpleRef(N) :: ds)(k))) rec(params.params.map(_.sym), Nil)(k) // TODO: restParam? case App(lhs, Tup(rhs)) => quote(lhs): r1 => def rec(es: Ls[Elem], xs: Ls[Path])(k: Result => Block): Block = es match @@ -1129,14 +1129,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign( arrSym, Tuple(mut = false, xs.reverse.map(_.asArg)), - setupTerm("Tup", arrSym.asSimpleRef :: Nil): r2 => + setupTerm("Tup", arrSym.asSimpleRef(N) :: Nil): r2 => val l1 = loweringCtx.registerTempSymbol(N) val l2 = loweringCtx.registerTempSymbol(N) - Assign(l1, r1, Assign(l2, r2, setupTerm("App", l1.asSimpleRef :: l2.asSimpleRef :: Nil)(k))) + Assign(l1, r1, Assign(l2, r2, setupTerm("App", l1.asSimpleRef(N) :: l2.asSimpleRef(N) :: Nil)(k))) ) case Fld(_, t, _) :: rest => quote(t): r2 => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r2, rec(rest, l.asSimpleRef :: xs)(k)) + Assign(l, r2, rec(rest, l.asSimpleRef(N) :: xs)(k)) case Spd(eager, term) :: rest => fail: ErrorReport( @@ -1152,17 +1152,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N) val arrSym = loweringCtx.registerTempSymbol(N, "arr") blockBuilder.assign(l1, r1) - .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) + .chain(b => setupTerm("Ref", l1.asSimpleRef(N) :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(rhs)(r2 => Assign(l2, r2, b))) .chain(b => quote(res)(r3 => Assign(l3, r3, b))) - .chain(b => setupTerm("LetDecl", l1.asSimpleRef :: Nil)(r4 => Assign(l4, r4, b))) - .chain(b => setupTerm("DefineVar", l1.asSimpleRef :: l2.asSimpleRef :: Nil)(r5 => Assign(l5, r5, b))) - .assign(arrSym, Tuple(mut = false, (l4 :: l5 :: Nil).map(s => s.asSimpleRef.asArg))) - .rest(setupTerm("Blk", arrSym.asSimpleRef :: l3.asSimpleRef :: Nil)(k)) + .chain(b => setupTerm("LetDecl", l1.asSimpleRef(N) :: Nil)(r4 => Assign(l4, r4, b))) + .chain(b => setupTerm("DefineVar", l1.asSimpleRef(N) :: l2.asSimpleRef(N) :: Nil)(r5 => Assign(l5, r5, b))) + .assign(arrSym, Tuple(mut = false, (l4 :: l5 :: Nil).map(s => s.asSimpleRef(N).asArg))) + .rest(setupTerm("Blk", arrSym.asSimpleRef(N) :: l3.asSimpleRef(N) :: Nil)(k)) } case IfLike(_, IfLikeForm.ReturningIf, split) => quoteSplit(split.getExpandedSplit): r => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, setupTerm("IfLike", setupQuotedKeyword("If") :: l.asSimpleRef :: Nil)(k)) + Assign(l, r, setupTerm("IfLike", setupQuotedKeyword("If") :: l.asSimpleRef(N) :: Nil)(k)) case Unquoted(body) => term(body)(k) case _ => fail: ErrorReport( @@ -1233,7 +1233,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign( rcdSym, Record(mut = false, fsr.reverse), - k((Arg(N, rcdSym.asSimpleRef) :: asr).reverse))) + k((Arg(N, rcdSym.asSimpleRef(N)) :: asr).reverse))) inline def plainArgs(ts: Ls[st])(k: Ls[Arg] => Block)(using LoweringCtx): Block = @@ -1262,7 +1262,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Define(lamDef, k(lamDef.asPath)) case r => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, k(l |> Value.SimpleRef.apply)) + Assign(l, r, k(l.asSimpleRef(N))) def program(main: st.Blk): Program = @@ -1416,12 +1416,12 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State) blockBuilder .assign(selRes, Select(p, nme)(disamb)) .assign(State.noSymbol, Select(p, Tree.Ident(nme.name+"$__checkNotMethod"))(N)) - .ifthen(selRes.asSimpleRef, + .ifthen(selRes.asSimpleRef(N), Case.Lit(syntax.Tree.UnitLit(false)), Throw(Instantiate(mut = false, Select(State.globalThisSymbol.asThis, Tree.Ident("Error"))(N), (Value.Lit(syntax.Tree.StrLit(s"Access to required field '${nme.name}' yielded 'undefined'")).asArg :: Nil) :: Nil)) ) - .rest(k(selRes.asSimpleRef)) + .rest(k(selRes.asSimpleRef(N))) @@ -1442,9 +1442,9 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) extension (k: Block => Block) def |>: (b: Block): Block = k(b) - private val traceLogFn = State.runtimeSymbol.asSimpleRef.selSN("TraceLogger").selSN("log") - private val traceLogIndentFn = State.runtimeSymbol.asSimpleRef.selSN("TraceLogger").selSN("indent") - private val traceLogResetFn = State.runtimeSymbol.asSimpleRef.selSN("TraceLogger").selSN("resetIndent") + private val traceLogFn = State.runtimeSymbol.asSimpleRef(N).selSN("TraceLogger").selSN("log") + private val traceLogIndentFn = State.runtimeSymbol.asSimpleRef(N).selSN("TraceLogger").selSN("indent") + private val traceLogResetFn = State.runtimeSymbol.asSimpleRef(N).selSN("TraceLogger").selSN("resetIndent") private val strConcatFn = selFromGlobalThis("String", "prototype", "concat", "call") private val inspectFn = selFromGlobalThis("util", "inspect") @@ -1477,34 +1477,34 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) val psSymArgs = psInspectedSyms.zipWithIndex.foldRight[Ls[Arg]](Arg(N, Value.Lit(Tree.StrLit(")"))) :: Nil): case (((s, p), i), acc) => if i == psInspectedSyms.length - 1 - then Arg(N, s.asSimpleRef) :: acc - else Arg(N, s.asSimpleRef) :: Arg(N, Value.Lit(Tree.StrLit(", "))) :: acc + then Arg(N, s.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: acc + else Arg(N, s.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: Arg(N, Value.Lit(Tree.StrLit(", "))) :: acc val tmp1, tmp2, tmp3 = loweringCtx.registerTempSymbol(N) assignStmts(psInspectedSyms.map: (pInspectedSym, pSym) => - pInspectedSym -> pureCall(inspectFn, Arg(N, pSym.asSimpleRef) :: Nil) + pInspectedSym -> pureCall(inspectFn, Arg(N, pSym.asSimpleRef(N)) :: Nil) *) |>: assignStmts( enterMsgSym -> pureCall( strConcatFn, Arg(N, Value.Lit(Tree.StrLit(s"CALL ${name.getOrElse("[arrow function]")}("))) :: psSymArgs ), - tmp1 -> pureCall(traceLogFn, Arg(N, enterMsgSym.asSimpleRef) :: Nil), + tmp1 -> pureCall(traceLogFn, Arg(N, enterMsgSym.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: Nil), prevIndentLvlSym -> pureCall(traceLogIndentFn, Nil) ) |>: term(bod)(r => assignStmts( resSym -> r, - resInspectedSym -> pureCall(inspectFn, Arg(N, resSym.asSimpleRef) :: Nil), + resInspectedSym -> pureCall(inspectFn, Arg(N, resSym.asSimpleRef(N)) :: Nil), retMsgSym -> pureCall( strConcatFn, - Arg(N, Value.Lit(Tree.StrLit("=> "))) :: Arg(N, resInspectedSym.asSimpleRef) :: Nil + Arg(N, Value.Lit(Tree.StrLit("=> "))) :: Arg(N, resInspectedSym.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: Nil ), - tmp2 -> pureCall(traceLogResetFn, Arg(N, prevIndentLvlSym.asSimpleRef) :: Nil), - tmp3 -> pureCall(traceLogFn, Arg(N, retMsgSym.asSimpleRef) :: Nil) + tmp2 -> pureCall(traceLogResetFn, Arg(N, prevIndentLvlSym.asSimpleRef(N)) :: Nil), + tmp3 -> pureCall(traceLogFn, Arg(N, retMsgSym.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: Nil) ) |>: - Ret(resSym.asSimpleRef) + Ret(resSym.asSimpleRef(N)) ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index 7b80670ce8..a815632645 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -164,7 +164,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): else doc def print(value: Value)(using Scope): Document = value match - case Value.SimpleRef(l) => print(l) + case Value.SimpleRef(l, _) => print(l) case Value.MemberRef(bms, disamb) => showSymbol(bms.nme, S(disamb)) case Value.This(sym) if sym === State.globalThisSymbol => showSymbol(sym.nme, S(sym.asDefnSym)) case Value.This(sym) => doc"${print(sym)}.this" diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala index 62c13f56ea..0f85824f2f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala @@ -55,7 +55,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n def assign(res: Result, symName: Str = "tmp")(k: Path => Block): Block = // TODO: skip assignment if res: Path? val sym = new TempSymbol(N, symName) - Scoped(Set(sym), Assign(sym, res, k(sym.asSimpleRef))) + Scoped(Set(sym), Assign(sym, res, k(sym.asSimpleRef(N)))) def tuple(elems: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = assign(Tuple(false, elems.map(asArg)), symName)(k) @@ -66,8 +66,8 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n // helpers for instrumenting Block - def blockMod(name: Str) = summon[State].blockSymbol.asSimpleRef.selSN(name) - def optionMod(name: Str) = summon[State].optionSymbol.asSimpleRef.selSN(name) + def blockMod(name: Str) = summon[State].blockSymbol.asSimpleRef(N).selSN(name) + def optionMod(name: Str) = summon[State].optionSymbol.asSimpleRef(N).selSN(name) def blockCtor(name: Str, args: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = call(blockMod(name), args, true, symName)(k) @@ -169,7 +169,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n // rulePath ctx.get(p).map(k).getOrElse: p match - case Value.SimpleRef(l) => + case Value.SimpleRef(l, _) => transformSymbol(l): sym => blockCtor("ValueSimpleRef", Ls(sym), "var")(k) case Value.MemberRef(bms, disamb) => @@ -213,7 +213,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n raise(ErrorReport(msg"Instantiate with multiple argument lists not supported in staged module." -> r.toLoc :: Nil)) End() // desugar Runtime.Tuple.get into Select - case Call(fun, Ls(Arg(_, scrut), Arg(_, Value.Lit(Tree.IntLit(idx)))) :: _) if fun == Value.SimpleRef(State.runtimeSymbol).selSN("Tuple").selSN("get") => + case Call(fun, Ls(Arg(_, scrut), Arg(_, Value.Lit(Tree.IntLit(idx)))) :: _) if fun == Value.SimpleRef(State.runtimeSymbol, N).selSN("Tuple").selSN("get") => transformPath(Select(scrut, Tree.Ident(idx.toString()))(N))(k) case Call(fun, argss) => val stagedFunPath = fun match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala index 55b88d3c89..f424306c89 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala @@ -279,7 +279,7 @@ private def matchChainToSwitch(m: MatchChain): SwitchLike = object SpecializedSwitch: def unapply(b: Block) = b match - case m @ Match(scrut = r @ Value.SimpleRef(l)) => + case m @ Match(scrut = r @ Value.SimpleRef(l, _)) => val chain = findMatchChainRec(m, r, Nil) val SwitchLike(scrut, cases, dflt, rest) = matchChainToSwitch(chain) if cases.size < 2 then N diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala index bb6ab15766..073e55ac09 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -12,7 +12,7 @@ import hkmc2.codegen.HandlerLowering.FnOrCls class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: StackSafetyMap)(using State, Config): private val STACK_DEPTH_IDENT: Tree.Ident = Tree.Ident("stackDepth") - private val runtimePath: Path = State.runtimeSymbol.asSimpleRef + private val runtimePath: Path = State.runtimeSymbol.asSimpleRef(N) private val checkDepthPath: Path = runtimePath.selN(Tree.Ident("checkDepth")) private val runStackSafePath: Path = runtimePath.selN(Tree.Ident("runStackSafe")) private val stackDepthPath: Path = runtimePath.selN(STACK_DEPTH_IDENT) @@ -20,7 +20,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S private def intLit(n: BigInt) = Value.Lit(Tree.IntLit(n)) private def op(op: String, a: Path, b: Path) = - Call(State.builtinOpsMap(op).asSimpleRef, (a.asArg :: b.asArg :: Nil) ne_:: Nil)(true, false, false) + Call(State.builtinOpsMap(op).asSimpleRef(N), (a.asArg :: b.asArg :: Nil) ne_:: Nil)(true, false, false) // Increases the stack depth, assigns the call to a value, then decreases the stack depth // then binds that value to a desired block @@ -30,7 +30,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S else blockBuilder .assign(sym, res) - .assignFieldN(runtimePath, STACK_DEPTH_IDENT, curDepth.asSimpleRef) + .assignFieldN(runtimePath, STACK_DEPTH_IDENT, curDepth.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int))) .rest(f(sym.asPath)) def wrapStackSafe(body: Block, resSym: Local, rest: Block) = @@ -48,7 +48,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S // Rewrites anything that can contain a Call to increase the stack depth def transform(b: Block, curDepth: => LocalVarSymbol, isTopLevel: Bool = false): Block = def usesStack(r: Result) = r match - case Call(Value.SimpleRef(_: BuiltinSymbol), _) => false + case Call(Value.SimpleRef(_: BuiltinSymbol, _), _) => false case c: Call if !c.mayRaiseEffects => false // a call can only trigger a stack delay if it can raise effects case _: Call | _: Instantiate => true case _ => false @@ -98,7 +98,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S new BlockTraverserShallow: applyBlock(b) override def applyResult(r: Result): Unit = r match - case Call(Value.SimpleRef(_: BuiltinSymbol), _) => () + case Call(Value.SimpleRef(_: BuiltinSymbol, _), _) => () case _: Call | _: Instantiate => trivial = false case _ => () trivial diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala index da41b40027..fadc239889 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala @@ -206,10 +206,10 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends case _ => super.applyPath(p)(k) override def applyValue(v: Value)(k: Value => Block): Block = v match - case Value.SimpleRef(l) => + case Value.SimpleRef(l, _) => mapping.get(l) match case Some(newSym: (LocalVarSymbol | BuiltinSymbol)) => - k(newSym.asSimpleRef) + k(newSym.asSimpleRef(N)) case _ => super.applyValue(v)(k) case Value.MemberRef(bms, disamb) => mapping.get(bms) match @@ -219,7 +219,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends case Some(nd) => lastWords(s"unexpected symbol kind for disamb: ${nd}") case N => lastWords(s"unexpected lack of refreshed disamb symbol for $disamb") k(newBms.asMemberRef(newDisamb)) - case Some(newSym: (LocalVarSymbol | TempSymbol)) => k(newSym.asSimpleRef) + case Some(newSym: (LocalVarSymbol | TempSymbol)) => k(newSym.asSimpleRef(N)) case _ => super.applyValue(v)(k) case Value.This(sym) => mapping.get(sym) match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 3bcf8a89ba..213ed666b6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -370,13 +370,13 @@ class TailRecOpt(using State, TL, Raise): // `assignedSyms` contains all of the param symbols that have been assigned to // before the current assignment, and thus references to them must be rewritten // to point to a temporary variable. - case Value.SimpleRef(l: VarSymbol) => assignedSyms.get(l) match + case Value.SimpleRef(l: VarSymbol, _) => assignedSyms.get(l) match case S(v) => val tmpSym = v.force_! // Adding this to `requiredTmps` will make sure we set the temporary variable // to the current variable at the start of the rewritten call. requiredTmps += (l, tmpSym) - k(tmpSym.asSimpleRef) + k(tmpSym.asSimpleRef(N)) case _ => super.applyValue(v)(k) case _ => super.applyValue(v)(k) @@ -387,7 +387,7 @@ class TailRecOpt(using State, TL, Raise): val selfAssigns = argListResults.flatMap: (_, thisParamSyms, args, argsRes) => argsRes match case CallArgsResult.Success(res) => thisParamSyms.zip(res).collect: - case (sym1, Value.SimpleRef(sym2)) if sym1 === sym2 => sym1 + case (sym1, Value.SimpleRef(sym2, _)) if sym1 === sym2 => sym1 case CallArgsResult.ForceSpread => List.empty assignedSyms --= selfAssigns @@ -413,7 +413,7 @@ class TailRecOpt(using State, TL, Raise): // in `rewrite`. Also note that `paramRewriter` will add all encountered rewritten variables // to `requiredTmps`. val ret = paramRewriter.applyResult(res)(Assign(sym, _, acc)) match - case Assign(sym, Value.SimpleRef(sym1), rest) if sym === sym1 => rest // avoid useless assignments + case Assign(sym, Value.SimpleRef(sym1, _), rest) if sym === sym1 => rest // avoid useless assignments case x => x ret case CallArgsResult.ForceSpread => @@ -433,7 +433,7 @@ class TailRecOpt(using State, TL, Raise): // Main args def mainArgs(rest: List[Path]) = (0 until paramList.size).toList.foldRight(rest): - case (n, acc) => DynSelect(tupleSym.asSimpleRef, Value.Lit(Tree.IntLit(n)), true) :: acc + case (n, acc) => DynSelect(tupleSym.asSimpleRef(N), Value.Lit(Tree.IntLit(n)), true) :: acc // If the rest param exists, append a slice val (initialBlk: (Block => Block), pathList: List[Path]) = @@ -441,10 +441,10 @@ class TailRecOpt(using State, TL, Raise): val sliceResSym = TempSymbol(N, "sliceRes") // runtime.Tuple.slice(tupleSym, paramList.length, 0) val sliceRes = Call( - State.runtimeSymbol.asSimpleRef + State.runtimeSymbol.asSimpleRef(N) .sel(Tree.Ident("Tuple"), State.tupleSymbol) .sel(Tree.Ident("slice"), State.tupleSliceSymbol), - (tupleSym.asSimpleRef.asArg + (tupleSym.asSimpleRef(N).asArg :: Value.Lit(Tree.IntLit(paramList.length)).asArg :: Value.Lit(Tree.IntLit(0)).asArg :: Nil) ne_:: Nil @@ -452,7 +452,7 @@ class TailRecOpt(using State, TL, Raise): val blk = blockBuilder .assignScoped(tupleSym, tupleRes) .assignScoped(sliceResSym, sliceRes) - (blk, mainArgs(sliceResSym.asSimpleRef :: Nil)) + (blk, mainArgs(sliceResSym.asSimpleRef(N) :: Nil)) else (blockBuilder.assignScoped(tupleSym, tupleRes), mainArgs(Nil)) end val @@ -466,7 +466,7 @@ class TailRecOpt(using State, TL, Raise): Scoped( requiredTmps.values.toSet, requiredTmps.toList.foldRight(assignments): - case ((v, l), acc) => Assign(l, v.asSimpleRef, acc)) + case ((v, l), acc) => Assign(l, v.asSimpleRef(N), acc)) // Not a tail call case _ => super.applyBlock(b) @@ -474,7 +474,7 @@ class TailRecOpt(using State, TL, Raise): // Rewrite the result with symbols pointing to the merged function parameters and possibly the copied parameters (see `copiedParams`). val blk = applyBlock(symRewriter.applyBlock(b)) val withCopied = copiedParamSyms.toArray.sortBy(_._1.uid).foldRight(blk): - case ((ogParam, copiedParam), accBlk) => Assign(copiedParam, paramSymsArr(paramsIdxes(ogParam)).asSimpleRef, accBlk) + case ((ogParam, copiedParam), accBlk) => Assign(copiedParam, paramSymsArr(paramsIdxes(ogParam)).asSimpleRef(N), accBlk) Scoped(copiedParamSyms.map(_._2).toSet, withCopied) val arms = funs.map: f => @@ -482,7 +482,7 @@ class TailRecOpt(using State, TL, Raise): val switch = if arms.length === 1 then arms.head._2 - else Match(curIdSym.asSimpleRef, arms, N, End()) + else Match(curIdSym.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), arms, N, End()) val loop = Label(loopSym, true, switch, End()) @@ -493,7 +493,7 @@ class TailRecOpt(using State, TL, Raise): val rewrittenFuns = if funs.size === 1 then Nil else funs.map: f => - val paramArgs = getParamSyms(f).map(s => s.asSimpleRef.asArg) + val paramArgs = getParamSyms(f).map(s => s.asSimpleRef(N).asArg) val args = Value.Lit(Tree.IntLit(dSymIds(f.dSym))).asArg :: paramArgs @@ -525,7 +525,7 @@ class TailRecOpt(using State, TL, Raise): owner, loopBms, loopDSym, PlainParamList(params) :: Nil, loop)(N, annotations = Annot.Private :: Nil) - val paramArgs = getParamSyms(f).map(s => s.asSimpleRef.asArg) + val paramArgs = getParamSyms(f).map(s => s.asSimpleRef(N).asArg) val internalSel = owner match case Some(value) => Select(value.asThis, Tree.Ident(loopBms.nme))(S(loopDSym)) case None => loopBms.asMemberRef(loopDSym) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala index bbe13057ee..42cabdf482 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala @@ -73,7 +73,7 @@ class UsedVarAnalyzer(b: Block, scopeData: ScopeData)(using State): case _ => super.applyBlock(b) override def applyPath(p: Path): Unit = p match - case Value.SimpleRef(_: BuiltinSymbol) => super.applyPath(p) + case Value.SimpleRef(_: BuiltinSymbol, _) => super.applyPath(p) case RefOfBms(_, SDSym(dSym), _) => val node = scopeData.getNode(dSym) node.obj match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala index e184e54240..6d43cab888 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala @@ -68,7 +68,7 @@ class WorkerWrapper workerBody, )(fun.configOverride, withoutInline(fun.annotations)) val workerArgs = fun.params.flatMap(_.params).map: param => - Arg(N, param.sym.asSimpleRef) + Arg(N, param.sym.asSimpleRef(N)) val wrapperBody = Return( Call(worker.asPath, workerArgs ne_:: Nil)(isMlsFun = true, mayRaiseEffects = true, explicitTailCall = false), ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala index 95a8d49cd6..d14e54bcb0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala @@ -272,7 +272,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): override def applyValue(v: Value): Unit = v match - case Value.SimpleRef(l) if !inCtx(l) => freeVars.add(l) + case Value.SimpleRef(l, _) if !inCtx(l) => freeVars.add(l) case Value.MemberRef(bms, _) if !inCtx(bms) && bms.asClsLike.isEmpty => freeVars.add(bms) case _ => super.applyValue(v) @@ -395,7 +395,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): r match case s@TrackableSelect(from, _, _) => if branchSelSyms.isDefinedAt(s.uid.concreteId) then - k(branchSelSyms(s.uid.concreteId).asSimpleRef) + k(branchSelSyms(s.uid.concreteId).asSimpleRef(N)) else if solver.finalDtorSrcs.contains(s.uid.concreteId) then applyPath(from)(k) else @@ -414,7 +414,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val ctorInfo = solver.fusingCtorInfo(ctor.uid.concreteId) val idx = ctorInfo.args.unzip._1.indexOf(field) val fieldSyms = mkCtorFieldSyms(ctor.uid.concreteId) - args.zip(fieldSyms).foldRight(k(fieldSyms(idx).asSimpleRef)): + args.zip(fieldSyms).foldRight(k(fieldSyms(idx).asSimpleRef(N))): case (Arg(N, a) -> s, rest) => applyPath(a): fusedField => Scoped(Set.single(s), Assign(s, fusedField, rest)) @@ -443,11 +443,11 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): Assign( lambdaSym, callBranchFun, - k(lambdaSym.asSimpleRef)) + k(lambdaSym.asSimpleRef(N))) ) case s@TrackableSelect(from, _, _) => if branchSelSyms.isDefinedAt(s.uid.concreteId) then - k(branchSelSyms(s.uid.concreteId).asSimpleRef) + k(branchSelSyms(s.uid.concreteId).asSimpleRef(N)) else if solver.finalDtorSrcs.contains(s.uid.concreteId) then applyPath(from)(k) else diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala index f1316f6d72..25fa96e0a6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala @@ -33,7 +33,7 @@ object FlowAnalysis: def getResult = resultIdToResult(resultId) def getReferredSym: Symbol = resultId.getResult match - case Value.SimpleRef(s) => s + case Value.SimpleRef(s, _) => s case Value.MemberRef(bms, _) => bms case Value.This(sym) => sym case e => lastWords(s"assumption failed: $e is not a SimpleRef, MemberRef, or ThisRef") @@ -90,7 +90,7 @@ type OriginId = ResultId | FunId /** Extracts the underlying symbol of a variable-like reference, for flow-tracking use. */ object TrackedSymOf: def unapply(p: Value.RefLike | Select)(using Elaborator.State): Opt[Symbol] = p match - case Value.SimpleRef(sym) => S(sym) + case Value.SimpleRef(sym, _) => S(sym) case Value.MemberRef(_, disamb) => S(disamb) case Value.This(sym) => S(sym) case s: Select => s.symbol.flatMap: selSym => @@ -116,8 +116,8 @@ object PossibleTrackableTupleSelect: def unapply(s: Result)(using eState: Elaborator.State): Opt[Value.SimpleRef -> Int] = s match case Call( - Select(Select(Value.SimpleRef(runtimeSym), Tree.Ident("Tuple")), Tree.Ident("get")), - (Arg(N, ref@Value.SimpleRef(scrut)) :: Arg(N, Value.Lit(Tree.IntLit(n))) :: Nil) :: Nil + Select(Select(Value.SimpleRef(runtimeSym, _), Tree.Ident("Tuple")), Tree.Ident("get")), + (Arg(N, ref@Value.SimpleRef(scrut, _)) :: Arg(N, Value.Lit(Tree.IntLit(n))) :: Nil) :: Nil ) if runtimeSym is eState.runtimeSymbol => S(ref -> n.toInt) case _ => N @@ -125,7 +125,7 @@ object TrackableSelect: def unapply(s: Result)(using pre: FlowPreAnalyzer, eState: Elaborator.State): Opt[(from: Path, field: SelField, owner: CtorCls)] = given fState: FlowAnalysis.State = pre.fState s match - case sel@PossibleTrackableTupleSelect((ref@Value.SimpleRef(scrut)) -> ith) => + case sel@PossibleTrackableTupleSelect((ref@Value.SimpleRef(scrut, _)) -> ith) => pre.res.getEnclosingMatchesForSel(sel.uid).find(_._1.getReferredSym is scrut).flatMap: case (_, Some(tupSize: Int)) => S(ref, ith, tupSize) case _ => N @@ -146,7 +146,7 @@ object CtorRef: yield cls def unapply(p: Path)(using Elaborator.State): Opt[ClassSymbol | ModuleOrObjectSymbol] = p match - case Value.SimpleRef(sym) => classCtorSymbol(sym) + case Value.SimpleRef(sym, _) => classCtorSymbol(sym) case Value.MemberRef(_, disamb) => classCtorSymbol(disamb) orElse disamb.asCls orElse disamb.asObj case Value.This(sym) => classCtorSymbol(sym) orElse sym.asCls case s: Select => s.symbol.flatMap(classCtorSymbol) @@ -408,7 +408,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using def isEnclosingMatchScrutSym(sym: Symbol): Boolean = ctx.exists: case InCtx.MtchBody(m, _) => m.scrut match - case Value.SimpleRef(s) => s is sym + case Value.SimpleRef(s, _) => s is sym case Value.MemberRef(bms, disamb) => disamb is sym case _ => false case _ => false @@ -559,7 +559,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using case p: Path => applyPath(p) private def applyValueSimpleRef(v: Value.SimpleRef, recordAffinity: Bool) = - val Value.SimpleRef(l) = v + val Value.SimpleRef(l, _) = v l match case s: TermSymbol => recordRefInCaptures(s) @@ -583,7 +583,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using case p@TrackableFieldSelect(qual, _ -> _) => res.selToCtxOfSel.addOne(p.uid -> ctxTracker.getAllCtx) qual match - case v@Value.SimpleRef(l) + case v@Value.SimpleRef(l, _) if ctxTracker.isEnclosingMatchScrutSym(l) => applyValueSimpleRef(v, recordAffinity = false) case v@Value.MemberRef(bms, disamb) @@ -595,7 +595,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using case v: Value => applyValue(v) override def applyValue(v: Value): Unit = v match - case v@Value.SimpleRef(l) => applyValueSimpleRef(v, recordAffinity = true) + case v@Value.SimpleRef(l, _) => applyValueSimpleRef(v, recordAffinity = true) case v@Value.MemberRef(_, _) => applyValueMemberRef(v, recordAffinity = true) case Value.This(sym) => () case Value.Lit(lit) => () diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index 816d9c3f2a..2e60402bcd 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -179,25 +179,25 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case Value.MemberRef(bms, disamb) => if disamb.shouldBeLifted then doc"${scope.lookup_!(bms, bms.toLoc)}.class" else scope.lookup_!(bms, r.toLoc) - case Value.SimpleRef(l: BuiltinSymbol) => + case Value.SimpleRef(l: BuiltinSymbol, _) => if l.nullary then l.nme else errExpr(msg"Illegal reference to builtin symbol '${l.nme}'") - case Value.SimpleRef(l: semantics.TermSymbol) => + case Value.SimpleRef(l: semantics.TermSymbol, _) => l.owner match case S(owner) => lastWords(s"Unexpected SimpleRef of TermSymbol with owner: `$l` (owner: `$owner`)") case N => scope.lookup_!(l, r.toLoc) - case Value.SimpleRef(l) => scope.lookup_!(l, r.toLoc) - case Call(Value.SimpleRef(l: BuiltinSymbol), (lhs :: rhs :: Nil) :: Nil) if !l.functionLike => + case Value.SimpleRef(l, _) => scope.lookup_!(l, r.toLoc) + case Call(Value.SimpleRef(l: BuiltinSymbol, _), (lhs :: rhs :: Nil) :: Nil) if !l.functionLike => if l.binary then val res = doc"${operand(lhs)} ${l.nme} ${operand(rhs)}" if needsParens(l.nme) then doc"(${res})" else res else errExpr(msg"Cannot call non-binary builtin symbol '${l.nme}'") - case Call(Value.SimpleRef(l: BuiltinSymbol), (rhs :: Nil) :: Nil) if !l.functionLike => + case Call(Value.SimpleRef(l: BuiltinSymbol, _), (rhs :: Nil) :: Nil) if !l.functionLike => if l.unary then val res = doc"${l.nme} ${operand(rhs)}" if needsParens(l.nme) then doc"(${res})" else res else errExpr(msg"Cannot call non-unary builtin symbol '${l.nme}'") - case Call(Value.SimpleRef(l: BuiltinSymbol), args :: Nil) => + case Call(Value.SimpleRef(l: BuiltinSymbol, _), args :: Nil) => if l.functionLike then val argsDoc = args.map(argument).mkDocument(", ") doc"${l.nme}(${argsDoc})" @@ -317,7 +317,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: val scrutSym = scrut.map(_.sym) b match case Match( - scrut_ @ Value.SimpleRef(scrutSym_), // The scrutinee is a ref. + scrut_ @ Value.SimpleRef(scrutSym_, _), // The scrutinee is a ref. (Case.Lit(Tree.IntLit(curVal_)), b) :: Nil, // There is only one case matching an int literal. S(End(_)) | N, rest // Default case exists and does nothing. ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 70bb4401d2..da7a1fe74e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -225,7 +225,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): def parentFromPath(p: Path): Ls[Local] = p match case Value.MemberRef(bms, disamb) => fromMemToClass(disamb) :: Nil case Value.This(sym) => fromMemToClass(sym) :: Nil - case Value.SimpleRef(l) => + case Value.SimpleRef(l, _) => // TODO(Derppening): Check if this assertion holds bErrStop(msg"Expected parent to be a MemberRef") case _ => bErrStop(msg"Unsupported parent path ${p.toString()}") @@ -281,7 +281,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): private def bValue(v: Value)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bValue { $v } begin", x => s"bValue end: ${x.show}"): v match - case Value.SimpleRef(l: TermSymbol) if l.owner.nonEmpty => + case Value.SimpleRef(l: TermSymbol, _) if l.owner.nonEmpty => k(l |> sr) case Value.MemberRef(bms, disamb) if bms.nme.isCapitalized => val v: Local = newTemp @@ -293,18 +293,18 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): val paramsList = PlainParamList( (0 until f.paramsSize).zip(tempSymbols).map((_n, sym) => Param(FldFlags.empty, sym, N, Modulefulness.none)).toList) - val app = Call(v, tempSymbols.map(x => Arg(N, x.asSimpleRef)).toList ne_:: Nil)(true, false, false) + val app = Call(v, tempSymbols.map(x => Arg(N, x.asSimpleRef(N))).toList ne_:: Nil)(true, false, false) bLam(Lambda(paramsList, Return(app))(Nil), S(bms.nme), N)(k) case None => k(ctx.findName(bms) |> sr) - case Value.SimpleRef(l) => + case Value.SimpleRef(l, _) => ctx.fn_ctx.get(l) match case Some(f) => val tempSymbols = (0 until f.paramsSize).map(x => newNamed("arg")) val paramsList = PlainParamList( (0 until f.paramsSize).zip(tempSymbols).map((_n, sym) => Param(FldFlags.empty, sym, N, Modulefulness.none)).toList) - val app = Call(v, tempSymbols.map(x => Arg(N, x.asSimpleRef)).toList ne_:: Nil)(true, false, false) + val app = Call(v, tempSymbols.map(x => Arg(N, x.asSimpleRef(N))).toList ne_:: Nil)(true, false, false) bLam(Lambda(paramsList, Return(app))(Nil), S(l.nme), N)(k) case None => k(ctx.findName(l) |> sr) @@ -379,7 +379,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): r match case Call(_, argss) if argss.sizeIs > 1 => bErrStop(msg"Calls with multiple argument lists are not yet supported in LLIR") - case Call(Value.SimpleRef(sym: BuiltinSymbol), argss) => + case Call(Value.SimpleRef(sym: BuiltinSymbol, _), argss) => bArgs(argss.flatten): case args: Ls[TrivialExpr] => val v: Local = newTemp @@ -421,7 +421,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): case args: Ls[TrivialExpr] => val v: Local = newTemp Node.LetCall(Ls(v), builtin, Expr.Literal(Tree.StrLit(mathPrimitive)) :: args, k(v |> sr)) - case Call(s @ Select(r @ Value.SimpleRef(sym), Tree.Ident(fld)), argss) if s.symbol.isDefined => + case Call(s @ Select(r @ Value.SimpleRef(sym, _), Tree.Ident(fld)), argss) if s.symbol.isDefined => bPath(r): case r => bArgs(argss.flatten): @@ -503,7 +503,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): summon[Ctx].def_acc += jpdef Node.Case(e, casesList, defaultCase) case Return(res) => bResult(res)(x => Node.Result(Ls(x))) - case Throw(Instantiate(false, Select(Value.SimpleRef(_), ident), + case Throw(Instantiate(false, Select(Value.SimpleRef(_, _), ident), Ls(Arg(N, Value.Lit(Tree.StrLit(e)))) :: Nil)) if ident.name === "Error" => Node.Panic(e) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 69e93b7f3f..7522f3dc6f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -926,7 +926,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: def splitSuperTail(block: Block): Opt[Block -> Ls[Arg]] = block match case End(_) => N - case Assign(lhs, Call(Value.SimpleRef(bs: BuiltinSymbol), argss), _: End) + case Assign(lhs, Call(Value.SimpleRef(bs: BuiltinSymbol, _), argss), _: End) if (lhs is State.noSymbol) && (bs is State.superSymbol) => S(End("") -> argss.flatten) @@ -1203,7 +1203,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: operands = Seq(ref.i31(i32.const(lit.offset)), ref.i31(i32.const(lit.byteLen))), returnTypes = Seq(Result(RefType.anyref)), ) - case Value.SimpleRef(l) => + case Value.SimpleRef(l, _) => singletonInfoFor(l) match case S(info) => singletonGlobalGet(info) case N => @@ -1236,7 +1236,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ), ) - case Call(Value.SimpleRef(l: BuiltinSymbol), lhs :: rhs :: Nil) if !l.functionLike => + case Call(Value.SimpleRef(l: BuiltinSymbol, _), lhs :: rhs :: Nil) if !l.functionLike => if l.binary then errExpr( Ls( @@ -1283,9 +1283,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ) case N => fun match - case Value.SimpleRef(l) => + case Value.SimpleRef(l, _) => val base = fun match - case Value.SimpleRef(l) => ctx.getFunc(l) + case Value.SimpleRef(l, _) => ctx.getFunc(l) case Value.MemberRef(l, _) => ctx.getFunc(l) case _ => N val baseFuncIdx = base match @@ -1486,7 +1486,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Returns the intrinsic name if `path` refers to a builtin under `wasm`, or `N` otherwise. */ private def wasmIntrinsicName(path: Path): Opt[Str] = path match - case Select(Value.SimpleRef(sym), ident) if (sym eq State.wasmSymbol) && wasmIntrinsicNameSet.contains(ident.name) => + case Select(Value.SimpleRef(sym, _), ident) if (sym eq State.wasmSymbol) && wasmIntrinsicNameSet.contains(ident.name) => S(ident.name) case _ => N diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 98424e1d2d..b1cb610676 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -367,7 +367,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C for (_, s) <- entries do LoweringCtx.loweringCtx.collectScopedSym(s) val objectSym = ctx.builtins.Object mkMatch( // checking that we have an object - Case.Cls(objectSym, BuiltinSymbol(objectSym.nme, false, false, true, false).asSimpleRef), + Case.Cls(objectSym, BuiltinSymbol(objectSym.nme, false, false, true, false).asSimpleRef(codegen.ErasedType.AnyRef(false, objectSym))), entries.foldRight(lowerSplit(tail, cont)): case ((fieldName, fieldSymbol), blk) => mkMatch( @@ -490,7 +490,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C if useNestedScoped then LoweringCtx.loweringCtx.getCollectedSym else Set.empty, mainBlock) // Embed the `body` into `Label` if the term is a `while`. - lazy val rest = if usesResTmp then k(l.asSimpleRef) else k(lowering.unit) + lazy val rest = if usesResTmp then k(l.asSimpleRef(N)) else k(lowering.unit) val block = if form === IfLikeForm.While then // NOTE: `shouldRewriteWhile` is not the same as `config.rewriteWhileLoops` @@ -501,16 +501,16 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C outerCtx.collectScopedSym(loopResult) outerCtx.collectScopedSym(isReturned) val loopEnd: Path = - Select(State.runtimeSymbol.asSimpleRef, Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) + Select(State.runtimeSymbol.asSimpleRef(N), Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) val blk = blockBuilder .define(FunDefn(N, f, tSym, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd)))(configOverride = N, annotations = Nil)) .assign(loopResult, Call(f.asMemberRef(tSym), Nil ne_:: Nil)(true, true, false)) if summon[LoweringCtx].mayRet then blk - .assign(isReturned, Call(State.builtinOpsMap("!==").asSimpleRef, + .assign(isReturned, Call(State.builtinOpsMap("!==").asSimpleRef(N), (loopResult.asPath.asArg :: loopEnd.asArg :: Nil) ne_:: Nil)(true, false, false)) - .ifthen(isReturned.asSimpleRef, Case.Lit(Tree.BoolLit(true)), - Return(loopResult.asSimpleRef), + .ifthen(isReturned.asSimpleRef(codegen.ErasedType.Primitive(codegen.PrimitiveType.Bool)), Case.Lit(Tree.BoolLit(true)), + Return(loopResult.asSimpleRef(N)), N ) .rest(rest) diff --git a/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls b/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls index 9f749b55ab..27ef775895 100644 --- a/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls +++ b/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls @@ -269,7 +269,7 @@ staged module A with :todo staged module Spread with fun f() = if [1, ..[1, 2]] is [1, ...x] then x else 0 -//│ ═══[COMPILATION ERROR] Spread parameters are not supported in staged module: Arg(Some(Lazy),SimpleRef(tmp:tmp)) +//│ ═══[COMPILATION ERROR] Spread parameters are not supported in staged module: Arg(Some(Lazy),SimpleRef(tmp:tmp,None)) //│ ═══[COMPILATION ERROR] No definition found in scope for member 'tmp' //│ > fun ctor_() = () //│ ═══[RUNTIME ERROR] Error: MLscript call unexpectedly returned `undefined`, the forbidden value. diff --git a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls index feaaf5a5e4..360940490b 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls @@ -29,11 +29,13 @@ x + 1 //│ lhs = x⁰ //│ rhs = Lit of IntLit of 1 //│ rest = Return of Call: -//│ fun = SimpleRef of builtin:+⁰ +//│ fun = SimpleRef: +//│ sym = builtin:+⁰ //│ argss = Ls of //│ Ls of //│ Arg: -//│ value = SimpleRef of x⁰ +//│ value = SimpleRef: +//│ sym = x⁰ //│ Arg: //│ value = Lit of IntLit of 1 //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— From d389627652dfdee418acb026574b3fb705a5077a Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 29 May 2026 15:24:51 +0800 Subject: [PATCH 06/48] WIP: Remove `SimpleRef.erasedType` This will be replaced with `Symbol.erasedType`. --- .../src/main/scala/hkmc2/codegen/Block.scala | 31 ++-- .../scala/hkmc2/codegen/BlockSimplifier.scala | 30 ++-- .../hkmc2/codegen/BlockTransformer.scala | 4 +- .../scala/hkmc2/codegen/BlockTraverser.scala | 2 +- .../hkmc2/codegen/BufferableTransform.scala | 16 +- .../scala/hkmc2/codegen/DeadParamElim.scala | 2 +- .../scala/hkmc2/codegen/EtaExpansion.scala | 2 +- .../FirstClassFunctionTransformer.scala | 10 +- .../scala/hkmc2/codegen/HandlerLowering.scala | 30 ++-- .../src/main/scala/hkmc2/codegen/Lifter.scala | 28 ++-- .../main/scala/hkmc2/codegen/Lowering.scala | 142 +++++++++--------- .../main/scala/hkmc2/codegen/Printer.scala | 2 +- .../codegen/ReflectionInstrumenter.scala | 10 +- .../hkmc2/codegen/SpecializedSwitch.scala | 2 +- .../hkmc2/codegen/StackSafeTransform.scala | 10 +- .../scala/hkmc2/codegen/SymbolRefresher.scala | 6 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 26 ++-- .../scala/hkmc2/codegen/UsedVarAnalyzer.scala | 2 +- .../scala/hkmc2/codegen/WorkerWrapper.scala | 2 +- .../hkmc2/codegen/deforest/Rewrite.scala | 10 +- .../codegen/flowAnalysis/FlowAnalysis.scala | 20 +-- .../scala/hkmc2/codegen/js/JSBuilder.scala | 14 +- .../scala/hkmc2/codegen/llir/Builder.scala | 16 +- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 12 +- .../hkmc2/semantics/ucs/Normalization.scala | 12 +- 25 files changed, 219 insertions(+), 222 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 7cdb507d16..1f41794194 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -545,7 +545,7 @@ object HandleBlock: N, Nil, S(par), handlerMtds, Nil, Nil, // Apparently, the lifter is not happy with any assignment in the preCtor... - Assign(State.noSymbol, Call(State.builtinOpsMap("super").asSimpleRef(N), args.map(_.asArg) ne_:: Nil)(true, true, false), End()), + Assign(State.noSymbol, Call(State.builtinOpsMap("super").asSimpleRef, args.map(_.asArg) ne_:: Nil)(true, true, false), End()), End(), N, N, @@ -556,7 +556,7 @@ object HandleBlock: .define(clsDefn) .assign(lhs, Instantiate(mut = true, clsDefn.sym.asMemberRef(cls), Nil :: Nil)) .define(bodyDefn) - .assign(res, handleSuspension(lhs.asSimpleRef(ErasedType.AnyRef(false, cls)), bodyDefn.sym.asMemberRef(bodyDefn.dSym))) + .assign(res, handleSuspension(lhs.asSimpleRef, bodyDefn.sym.asMemberRef(bodyDefn.dSym))) .rest(rest) def apply( @@ -899,7 +899,7 @@ sealed abstract class Result extends AutoLocated: // def extraInfo: Str = toLoc.toString def showDbg(using DebugPrinter): Str = this match - case Value.SimpleRef(l, _) => l.showAsPlain + case Value.SimpleRef(l) => l.showAsPlain case Value.MemberRef(l, disamb) => s"${l.showAsPlain}${s"‹${disamb.showAsPlain}›"}" case Value.This(sym) => s"this[${sym.showAsPlain}]" case Value.Lit(lit) => lit.idStr @@ -919,7 +919,7 @@ sealed abstract class Result extends AutoLocated: q.isPure && sel.symbol.exists(_.isPure) case c @ Call(fun, ass) if c.isKnownUnsaturatedCall => fun.isPure && ass.forall(_.forall(a => a.spread.isEmpty && a.value.isPure)) - case Call(Value.SimpleRef(bs: BuiltinSymbol, _), ass) if bs.isPure => + case Call(Value.SimpleRef(bs: BuiltinSymbol), ass) if bs.isPure => ass.forall(_.forall(_.value.isPure)) case Record(mut, args) => args.forall(_.value.isPure) case Tuple(mut, elems) => elems.forall(_.value.isPure) @@ -938,7 +938,7 @@ sealed abstract class Result extends AutoLocated: case Lambda(params, body) => Vector.single(params) case Tuple(mut, elems) => elems.iterator.map(_.value).toVector case Record(mut, elems) => elems.iterator.map(_.value).toVector - case Value.SimpleRef(l, _) => Vector.empty + case Value.SimpleRef(l) => Vector.empty case Value.MemberRef(bms, disamb) => Vector.empty case Value.This(sym) => Vector.empty case Value.Lit(lit) => Vector.single(lit) @@ -960,7 +960,7 @@ sealed abstract class Result extends AutoLocated: case Tuple(mut, elems) => elems.flatMap(_.value.freeVars).toSet case Record(mut, args) => args.flatMap(arg => arg.idx.fold(Set.empty)(_.freeVars) ++ arg.value.freeVars).toSet - case Value.SimpleRef(l, _) => Set(l) + case Value.SimpleRef(l) => Set(l) case Value.MemberRef(bms, _) => Set(bms) case Value.This(sym) => Set.empty case Value.Lit(lit) => Set.empty @@ -974,12 +974,12 @@ sealed abstract class Result extends AutoLocated: case Tuple(mut, elems) => elems.flatMap(_.value.freeVarsLLIR).toSet case Record(mut, args) => args.flatMap(arg => arg.idx.fold(Set.empty)(_.freeVarsLLIR) ++ arg.value.freeVarsLLIR).toSet - case Value.SimpleRef(l: (BuiltinSymbol | TermSymbol), _) => Set.empty - case Value.SimpleRef(l: DefinitionSymbol[?], _) => l.defn match + case Value.SimpleRef(l: (BuiltinSymbol | TermSymbol)) => Set.empty + case Value.SimpleRef(l: DefinitionSymbol[?]) => l.defn match case Some(d: ClassLikeDef) => Set.empty case Some(d: TermDefinition) if d.companionClass.isDefined => Set.empty case _ => Set(l) - case Value.SimpleRef(l, _) => Set(l) + case Value.SimpleRef(l) => Set(l) case Value.MemberRef(l: (ClassSymbol | TermSymbol), disamb) => Set.empty case Value.MemberRef(l, disamb) => disamb.defn match case Some(d: ClassLikeDef) => Set.empty @@ -1048,7 +1048,7 @@ case class Select(qual: Path, name: Tree.Ident)(val symbol: Opt[DefinitionSymbol case class DynSelect(qual: Path, fld: Path, arrayIdx: Bool) extends Path enum Value extends Path with HasErasedType with ProductWithExtraInfo: - case SimpleRef(sym: LocalVarSymbol | BuiltinSymbol, _erasedType: Opt[ErasedType]) + case SimpleRef(sym: LocalVarSymbol | BuiltinSymbol) /** * @param disamb The symbol disambiguating the definition that the reference refers to. */ @@ -1080,7 +1080,7 @@ object Value: extension (r: RefLike) def symbol: Symbol = r match - case SimpleRef(l, _) => l + case SimpleRef(l) => l case MemberRef(bms, _) => bms case This(sym) => sym @@ -1088,7 +1088,7 @@ object Value: object Ref: def apply(l: Local, disamb: Opt[DefinitionSymbol[?]]): Value.RefLike = l match - case l: (LocalVarSymbol | BuiltinSymbol) => l.asSimpleRef(N) + case l: (LocalVarSymbol | BuiltinSymbol) => l.asSimpleRef case bms: BlockMemberSymbol => bms.asMemberRef: disamb.getOrElse: lastWords(s"Cannot disambiguate overloaded member symbol ${bms.nme}: no disambiguation provided") @@ -1103,7 +1103,7 @@ object Value: def apply(l: TempSymbol | VarSymbol | BuiltinSymbol): Value.RefLike = Ref(l, N) def unapply(v: Value): Opt[(Local, Opt[DefinitionSymbol[?]])] = v match - case SimpleRef(l, _) => S(l -> N) + case SimpleRef(l) => S(l -> N) case MemberRef(bms, disamb) => S(bms -> S(disamb)) case This(sym) => S(sym -> N) case _ => N @@ -1141,10 +1141,7 @@ extension (k: Block => Block) def blockBuilder: Block => Block = identity extension (s: (LocalVarSymbol | BuiltinSymbol)) - @deprecated("Use the overload accepting `Opt[ErasedType]` instead.") - inline def asSimpleRef: Value.SimpleRef = Value.SimpleRef(s, N) - inline def asSimpleRef(erasedType: ErasedType): Value.SimpleRef = Value.SimpleRef(s, S(erasedType)) - inline def asSimpleRef(erasedType: Opt[ErasedType]): Value.SimpleRef = Value.SimpleRef(s, erasedType) + inline def asSimpleRef: Value.SimpleRef = Value.SimpleRef(s) extension (bms: BlockMemberSymbol) inline def asMemberRef(disamb: DefinitionSymbol[?]): Value.MemberRef = Value.MemberRef(bms, disamb) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala index a3cc2de9c7..3292609efd 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala @@ -129,7 +129,7 @@ class BlockSimplifier case ts: TermSymbol => usedPrivateFields += ts case _ => - case Value.SimpleRef(loc, _) => + case Value.SimpleRef(loc) => usedVars += loc case Value.MemberRef(loc, _) => usedVars += loc @@ -216,7 +216,7 @@ class BlockSimplifier override def applyValue(v: Value)(k: Value => Block) = v match // * Replace with `undefined` those references to local variables that are never assigned - case Value.SimpleRef(loc, _) if localVars.contains(loc) && !definedVars.contains(loc) => + case Value.SimpleRef(loc) if localVars.contains(loc) && !definedVars.contains(loc) => registerChange(s"${loc.showDbg} is never assigned; replacing read with undefined") // if !symbolsToPreserve(loc) then removedLocals += loc k(Value.Lit(syntax.Tree.UnitLit(false))) @@ -503,7 +503,7 @@ class BlockSimplifier // * Discard local variables that are assigned just to be returned // * Note: the reason we do this here and not in DeadCodeElim is that we need to check `capturedVars` - case Assign(lhs: LocalVar, rhs, Return(Value.SimpleRef(ret, _))) + case Assign(lhs: LocalVar, rhs, Return(Value.SimpleRef(ret))) if !inDryRun && (ret is lhs) && !capturedVars(lhs) && !symbolsToPreserve(lhs) => registerChange(s"tail-return ${lhs.showDbg} ~> ${rhs.showDbg}") @@ -513,7 +513,7 @@ class BlockSimplifier // log(s"Propagating ${lhs} := ${rhs} (${assignedResults.get(lhs)})") assignedResults += lhs -> Assigned(ass, rhs.match - case r @ Value.SimpleRef(sym: LocalVar, _) => + case r @ Value.SimpleRef(sym: LocalVar) => if capturedVars(sym) then N else val rhs2 = assignedResults(sym) @@ -640,9 +640,9 @@ class BlockSimplifier if gaveUp then Set.empty else p match - case Value.SimpleRef(r: LocalVar, _) if capturedVars(r) => + case Value.SimpleRef(r: LocalVar) if capturedVars(r) => giveUp - case Value.SimpleRef(r: LocalVar, _) => + case Value.SimpleRef(r: LocalVar) => assignedResults.get(r).fold(giveUp)(getShapesA) case Value.MemberRef(r, sym: ModuleOrObjectSymbol) => Set.single(sym) @@ -724,7 +724,7 @@ class BlockSimplifier override def applyValue(v: Value)(k: Value => Block): Block = v match - case Value.SimpleRef(loc: LocalVar, _) if !inDryRun && !capturedVars(loc) => + case Value.SimpleRef(loc: LocalVar) if !inDryRun && !capturedVars(loc) => val rs = assignedResults(loc) // log(s"Ref ${loc.showDbg} ${rs} ${localVars(loc)} ${capturedVars(loc)}") @@ -765,7 +765,7 @@ class BlockSimplifier case _ => litValue = false opt match - case S((r @ Value.SimpleRef(lv: LocalVar, _)) -> rhs) => + case S((r @ Value.SimpleRef(lv: LocalVar)) -> rhs) => if assignedResults(lv) is rhs then Set.single(r) ++ analyzeValues(rhs) else Set.empty @@ -813,7 +813,7 @@ class BlockSimplifier case call: Call if call.isKnownUnsaturatedCall && call.isPure => S(call) case _ => opt match - case S((Value.SimpleRef(next: LocalVar, _), nextAsst)) + case S((Value.SimpleRef(next: LocalVar), nextAsst)) if !capturedVars(next) && !seen(next) && (assignedResults(next) is nextAsst) => loop(nextAsst, seen + next) case _ => N @@ -829,7 +829,7 @@ class BlockSimplifier r match // * Try to propagate pure calls - case Value.SimpleRef(loc: LocalVar, _) if !inDryRun && !capturedVars(loc) => + case Value.SimpleRef(loc: LocalVar) if !inDryRun && !capturedVars(loc) => assignedPureCallPrefix(loc) match case S(call) => registerChange(s"${loc.showDbg} ~> ${call.showDbg}") @@ -838,7 +838,7 @@ class BlockSimplifier super.applyResult(r)(k) // * Try to combine pure calls (typically unsaturated calls) assigned to a variable into the current call - case c @ Call(Value.SimpleRef(loc: LocalVar, _), argss) if !inDryRun && !capturedVars(loc) => + case c @ Call(Value.SimpleRef(loc: LocalVar), argss) if !inDryRun && !capturedVars(loc) => assignedPureCallPrefix(loc) match case S(prefix) => registerChange(s"${loc.showDbg} call prefix ~> ${prefix.showDbg}") @@ -849,13 +849,13 @@ class BlockSimplifier case N => super.applyResult(r)(k) // * Remove uses of the strange builtin comma operator - case Call(Value.SimpleRef(sym: BuiltinSymbol, _), (arg1 :: arg2 :: Nil) :: Nil) + case Call(Value.SimpleRef(sym: BuiltinSymbol), (arg1 :: arg2 :: Nil) :: Nil) if sym.nme === "," && arg1.spread.isEmpty && arg2.spread.isEmpty => Assign.discard(arg1.value, k(arg2.value)) // * Partially evaluate calls to known builtins with literal arguments - case Call(Value.SimpleRef(sym: BuiltinSymbol, _), args :: Nil) if args.forall(_.value.isInstanceOf[Value]) => + case Call(Value.SimpleRef(sym: BuiltinSymbol), args :: Nil) if args.forall(_.value.isInstanceOf[Value]) => val argValues = args.map(_.value.asInstanceOf[Value]) args.foreach(a => assert(a.spread.isEmpty)) builtinEval.lift((sym.nme, argValues)) match @@ -1183,10 +1183,10 @@ class BlockSimplifier val copier = Copier(resSym, mapping) val newBlk = copier.applyBlock(blk) if extraArgss.isEmpty then - acc(Scoped(Set.single(resSym), newBlk(k(resSym.asSimpleRef(N))))) + acc(Scoped(Set.single(resSym), newBlk(k(resSym.asSimpleRef)))) else acc(Scoped(Set(resSym), newBlk( - k(Call(resSym.asSimpleRef(N), extraArgss.ne_!)(c.isMlsFun, c.mayRaiseEffects, false))))) + k(Call(resSym.asSimpleRef, extraArgss.ne_!)(c.isMlsFun, c.mayRaiseEffects, false))))) case (sym, value) :: argRest => val newSym = VarSymbol(sym.id) go(acc.assignScoped(newSym, value), argRest, mapping + (sym -> newSym)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala index b30d299dbf..c30ebbddfc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala @@ -178,12 +178,12 @@ class BlockTransformer(subst: SymbolSubst): case v: Value => applyValue(v)(k) def applyValue(v: Value)(k: Value => Block) = v match - case Value.SimpleRef(l, _) => + case Value.SimpleRef(l) => val l2 = applyLocal(l) match case l: (LocalVarSymbol | BuiltinSymbol) => l case l2 => lastWords(s"Expected applyValue on `$l` (${l.getClass.getSimpleName}) to create a symbol of the same type, but got `$l2` (${l2.getClass.getSimpleName})") - k(if (l2 is l) then v else l2.asSimpleRef(N).withLocOf(v)) + k(if (l2 is l) then v else l2.asSimpleRef.withLocOf(v)) case Value.MemberRef(bms, disamb) => val bms2 = bms.subst val disamb2 = disamb.subst diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala index 33b04fd3af..24f44aa1ee 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala @@ -71,7 +71,7 @@ class BlockTraverser: case v: Value => applyValue(v) def applyValue(v: Value): Unit = v match - case Value.SimpleRef(l, _) => l.traverse + case Value.SimpleRef(l) => l.traverse case Value.MemberRef(bms, disamb) => bms.traverse disamb.traverse diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala index 9e1bf3b175..d5a102b1ed 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala @@ -36,17 +36,17 @@ class BufferableTransform()(using Ctx, State, Raise): Param(p.flags, varMap(p.sym), p.sign, p.modulefulness) (params.map(pl => ParamList(pl.flags, pl.params.map(mapParam), pl.restParam.map(mapParam))), varMap.toMap) def mkFieldReplacer(buf: VarSymbol, baseIdx: VarSymbol, symMap: Map[Symbol, Symbol]) = - val baseIdxRef = baseIdx.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)) + val baseIdxRef = baseIdx.asSimpleRef def getOffset(off: Int)(k: Path => Block): Block = val idxSymbol = new TempSymbol(N, "idx") - val idxSymbolRef = idxSymbol.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)) - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef(N), (baseIdxRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), - k(DynSelect(buf.asSimpleRef(N).selSN("buf"), idxSymbolRef, true)))) + val idxSymbolRef = idxSymbol.asSimpleRef + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdxRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), + k(DynSelect(buf.asSimpleRef.selSN("buf"), idxSymbolRef, true)))) def assignToOffset(off: Int, r: Result, rst: Block) = val idxSymbol = new TempSymbol(N, "idx") - val idxSymbolRef = idxSymbol.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)) - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef(N), (baseIdxRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), - AssignDynField(buf.asSimpleRef(N).selSN("buf"), idxSymbolRef, true, r, applyBlock(rst)))) + val idxSymbolRef = idxSymbol.asSimpleRef + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdxRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), + AssignDynField(buf.asSimpleRef.selSN("buf"), idxSymbolRef, true, r, applyBlock(rst)))) new BlockTransformer(SymbolSubst.Id): override def applyLocal(sym: Local): Local = symMap.getOrElse(sym, sym) override def applyBlock(b: Block): Block = b match @@ -82,7 +82,7 @@ class BufferableTransform()(using Ctx, State, Raise): val blk = mkFieldReplacer(buf, idx, symMap).applyBlock(f.body) FunDefn(f.owner, f.sym, TermSymbol(f.dSym.k, f.dSym.owner, f.dSym.id), PlainParamList( Param(FldFlags.empty, buf, N, Modulefulness.none) :: Param(FldFlags.empty, idx, N, Modulefulness.none) :: Nil) :: newParams, - if isCtor then Begin(blk, Return(idx.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)))) else blk)(configOverride = f.configOverride, annotations = f.annotations) + if isCtor then Begin(blk, Return(idx.asSimpleRef)) else blk)(configOverride = f.configOverride, annotations = f.annotations) val fakeCtor = transformFunDefn(FunDefn.withFreshSymbol( S(companionSym), BlockMemberSymbol("ctor", Nil, false), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala index 61aa63d63c..2b75982a87 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala @@ -237,7 +237,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise): case _ => super.applyPath(p)(k) override def applyValue(v: Value)(k: Value => Block): Block = v match - case ref@Value.SimpleRef(l: VarSymbol, _) if activeEliminatedParams(l) => + case ref@Value.SimpleRef(l: VarSymbol) if activeEliminatedParams(l) => k(Value.Lit(Tree.UnitLit(false)).withLocOf(ref)) case _ => super.applyValue(v)(k) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala index ca274b3ea1..7bcca2730a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala @@ -175,7 +175,7 @@ class EtaExpansionRewrite(val etaExpansionSolver: EtaExpansionSolver)(using Rais Param.simple(new VarSymbol(new Tree.Ident(s"eta$$$idx$$$i"))) EtaParamList( ParamList(ParamListFlags.empty, params, N), - params.map(p => Arg(N, p.sym.asSimpleRef(N))), + params.map(p => Arg(N, p.sym.asSimpleRef)), ) else lastWords("not the same shape?") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala index 7af1152a95..0d924e89ce 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala @@ -28,11 +28,11 @@ class FirstClassFunctionTransformer ) val defSym = new BlockMemberSymbol("Function$", Nil, false) val callDef = FunDefn.withFreshSymbol(Some(clsSym), new BlockMemberSymbol("call", Nil, true), params :: Nil, - Return(Call(p, params.params.map(_.sym.asSimpleRef(N).asArg) ne_:: Nil)(true, false, false)))(N, annotations = Nil) + Return(Call(p, params.params.map(_.sym.asSimpleRef.asArg) ne_:: Nil)(true, false, false)))(N, annotations = Nil) ClsLikeDefn(None, clsSym, defSym, None, syntax.Cls, None, Nil, Some(Select(State.globalThisSymbol.asThis, Tree.Ident("Function"))(Some(ctx.builtins.Function))), callDef :: Nil, Nil, Nil, Assign.discard( - Call(State.builtinOpsMap("super").asSimpleRef(N), Nil ne_:: Nil)(false, false, false), + Call(State.builtinOpsMap("super").asSimpleRef, Nil ne_:: Nil)(false, false, false), End()), End(), None, None)(N, annotations = Nil) private def getParamList(l: BlockMemberSymbol): Option[ParamList] = funDefns.get(l) match @@ -56,7 +56,7 @@ class FirstClassFunctionTransformer Assign( tmp, Instantiate(false, cls, Nil :: Nil), - k(tmp.asSimpleRef(ErasedType.AnyRef(rsc = false, clsDef.isym.asClsOrMod.get))), + k(tmp.asSimpleRef), ) ) ) @@ -79,7 +79,7 @@ class FirstClassFunctionTransformer Assign( tmp, Instantiate(false, cls, Nil :: Nil), - k(tmp.asSimpleRef(ErasedType.AnyRef(rsc = false, clsDef.isym.asClsOrMod.get))), + k(tmp.asSimpleRef), ) ) ) @@ -100,7 +100,7 @@ class FirstClassFunctionTransformer case c @ Call(fun, argss) => applyListOf(argss, (args, k2) => applyArgs(args)(k2)): argss2 => def call(f: Path) = Call(f, argss2.ne_!)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall) fun match - case ref @ Value.SimpleRef(sym, N) => sym match + case ref @ Value.SimpleRef(sym) => sym match case _: VarSymbol | _: TempSymbol => k(call(ref.selSN("call"))) case _ => k(call(fun)) case ref @ Value.MemberRef(_, _) => k(call(fun)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index e9b4e50f97..11263f63b2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -92,7 +92,7 @@ object HandlerLowering: import HandlerLowering.* class HandlerPaths(using Elaborator.State): - val runtimePath: Path = State.runtimeSymbol.asSimpleRef(N) + val runtimePath: Path = State.runtimeSymbol.asSimpleRef val contClsPath: Path = runtimePath.selSN("FunctionContFrame").selSN("class") val mkEffectPath: Path = runtimePath.selSN("mkEffect") val handleBlockImplPath: Path = runtimePath.selSN("handleBlockImpl") @@ -138,18 +138,18 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, object StateTransition: private val transitionSymbol = freshTmp("transition") def apply(uid: StateId) = - Return(PureCall(transitionSymbol.asSimpleRef(N), List(Value.Lit(Tree.IntLit(uid))))) + Return(PureCall(transitionSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid))))) def unapply(blk: Block) = blk match - case Return(PureCall(Value.SimpleRef(`transitionSymbol`, N), List(Value.Lit(Tree.IntLit(uid))))) => + case Return(PureCall(Value.SimpleRef(`transitionSymbol`), List(Value.Lit(Tree.IntLit(uid))))) => S(uid) case _ => N object Unwind: private val unwindSymbol = freshTmp("unwind") def apply(uid: StateId, loc: Value) = - Return(PureCall(unwindSymbol.asSimpleRef(N), List(Value.Lit(Tree.IntLit(uid)), loc))) + Return(PureCall(unwindSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid)), loc))) def unapply(blk: Block) = blk match - case Return(PureCall(Value.SimpleRef(`unwindSymbol`, N), List(Value.Lit(Tree.IntLit(uid)), loc: Value))) => + case Return(PureCall(Value.SimpleRef(`unwindSymbol`), List(Value.Lit(Tree.IntLit(uid)), loc: Value))) => S(uid, loc) case _ => N @@ -540,9 +540,9 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val debugInfoSym = freshTmp(s"$debugNme$$debugInfo") // TODO: properly support spread argument by calculating the correct length. val rtArgLists = intLit(fun.params.length) :: fun.params.flatMap: pl => - intLit(pl.params.length) :: pl.params.map(p => p.sym.asSimpleRef(N)) + intLit(pl.params.length) :: pl.params.map(p => p.sym.asSimpleRef) val newCtx = HandlerCtx.FunctionLike(FunctionCtx(funcPath, thisPath, ResumeInfo(rtArgLists, varList, L(fun.sym)), - DebugInfo(debugNme, if opt.debug then debugInfoSym.asSimpleRef(N) else unit), thisPath.isDefined && fun.params.isEmpty)) + DebugInfo(debugNme, if opt.debug then debugInfoSym.asSimpleRef else unit), thisPath.isDefined && fun.params.isEmpty)) val bod2 = translateBlock(fun.body, newCtx, scopedVars) val fun2 = if fun.body is bod2 then fun else FunDefn(fun.owner, fun.sym, fun.dSym, fun.params, bod2)(fun.configOverride, fun.annotations) @@ -658,7 +658,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val headTransformed = segmentTailTransform.applyBlock(parts.states(head).blk) val initial: Block => Block = blk => Match( - pcVar.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), + pcVar.asSimpleRef, Case.Lit(Tree.IntLit(head)) -> headTransformed :: Nil, N, blk @@ -670,7 +670,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val transformed = transformState(uid) blk => Match( - pcVar.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), + pcVar.asSimpleRef, Case.Lit(Tree.IntLit(uid)) -> transformed :: Nil, N, acc(blk) @@ -691,17 +691,17 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, def getSaved(off: BigInt): (Block => Block, Path) = if off == 0 then return (id, DynSelect(paths.runtimePath.selSN("resumeArr"), paths.runtimePath.selSN("resumeIdx"), true)) - val addOne = Assign(getSavedTmp, Call(State.builtinOpsMap("+").asSimpleRef(N), (paths.runtimePath.selSN("resumeIdx").asArg :: intLit(off).asArg :: Nil) ne_:: Nil)(false, false, false), _) - (addOne, DynSelect(paths.runtimePath.selSN("resumeArr"), getSavedTmp.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), true)) + val addOne = Assign(getSavedTmp, Call(State.builtinOpsMap("+").asSimpleRef, (paths.runtimePath.selSN("resumeIdx").asArg :: intLit(off).asArg :: Nil) ne_:: Nil)(false, false, false), _) + (addOne, DynSelect(paths.runtimePath.selSN("resumeArr"), getSavedTmp.asSimpleRef, true)) - val resumeArrIndexed = DynSelect(paths.runtimePath.selSN("resumeArr"), getSavedTmp.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), true) - val plus = State.builtinOpsMap("+").asSimpleRef(N) + val resumeArrIndexed = DynSelect(paths.runtimePath.selSN("resumeArr"), getSavedTmp.asSimpleRef, true) + val plus = State.builtinOpsMap("+").asSimpleRef val preRestore = blockBuilder .assign(pcVar, paths.resumePc) .scopedVars(Set(getSavedTmp)) val restoreVars = vars.zipWithIndex.foldLeft(preRestore): case (builder, (local, idx)) => builder - .assign(getSavedTmp, if idx == 0 then paths.resumeIdx else Call(plus, (getSavedTmp.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)).asArg :: intLit(1).asArg :: Nil) ne_:: Nil)(false, false, false)) + .assign(getSavedTmp, if idx == 0 then paths.resumeIdx else Call(plus, (getSavedTmp.asSimpleRef.asArg :: intLit(1).asArg :: Nil) ne_:: Nil)(false, false, false)) .assign(local, resumeArrIndexed) Scoped( @@ -737,7 +737,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case EffectfulResult(r) => // Fallback case, this may lead to unnecessary assignments if it is assign-like val l = freshTmp() - Scoped(Set(l), effectCheck(l, r, k(l.asSimpleRef(N)))) + Scoped(Set(l), effectCheck(l, r, k(l.asSimpleRef))) case _ => super.applyResult(r)(k) topLevelTransform.applyBlock(b) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index ee06f97bda..56a802ebc3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -387,7 +387,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): case Some(value) => syms.addOne(FunSyms(l, d) -> value) value - k(newSym.asSimpleRef(N)) + k(newSym.asSimpleRef) // Naked reference to a parameterized class constructor (used as a first-class function). // Replace with a partially applied curried C$ wrapper. @@ -404,7 +404,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): case Some(value) => syms.addOne(FunSyms(l, d) -> value) value - k(newSym.asSimpleRef(N)) + k(newSym.asSimpleRef) case _ => resolveDefnRef(l, d, ctor) match case Some(value) => k(value) @@ -557,7 +557,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val vd = ValDefn( tSym, fldSym, - varSym.asSimpleRef(N) + varSym.asSimpleRef )(N, Nil) (sym -> varSym, p, vd) @@ -833,7 +833,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): */ sealed trait GenericRewrittenScope[T] extends RewrittenScope[T]: lazy val captureSym = VarSymbol(Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath = captureSym.asSimpleRef(N) + override lazy val capturePath = captureSym.asSimpleRef protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid) protected val liftedObjsSyms: Map[InnerSymbol, VarSymbol] = liftedObjsOrdered.map: s => s -> VarSymbol(Tree.Ident(s.nme + "$")) @@ -848,7 +848,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): */ sealed trait ClsLikeRewrittenScope[T](sym: InnerSymbol) extends RewrittenScope[T]: lazy val captureSym = TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath = captureSym.asSimpleRef(N) + override lazy val capturePath = captureSym.asSimpleRef protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid) protected val liftedObjsSyms: Map[InnerSymbol, TermSymbol] = liftedObjsOrdered.map: s => s -> TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(s.nme + "$")) @@ -915,7 +915,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = captureSym.asSimpleRef(N) + override lazy val capturePath: Path = captureSym.asSimpleRef override def rewriteImpl: LifterResult[ClsLikeDefn] = val rewriterCtor = new BlockRewriter @@ -938,7 +938,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): with ClsLikeRewrittenScope[ClsLikeBody](obj.clsBody.isym): private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.clsBody.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = captureSym.asSimpleRef(N) + override lazy val capturePath: Path = captureSym.asSimpleRef override def rewriteImpl: LifterResult[ClsLikeBody] = val rewriterCtor = new BlockRewriter @@ -966,7 +966,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): .toMap override protected val passedSymsMap = passedSymsMap_.view.mapValues(_.asLocalPath).toMap - override protected val capSymsMap = capSymsMap_.view.mapValues(s => s.asSimpleRef(N)).toMap + override protected val capSymsMap = capSymsMap_.view.mapValues(s => s.asSimpleRef).toMap override protected val passedDefnsMap = defnSymsMap_.view.mapValues(_.asDefnRef).toMap val auxParams: List[Param] = @@ -1012,10 +1012,10 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): case Nil => lastWords("tried to make an aux defn for a function with no parameter list") val args = restSym match case Some(value) => - val tail = Arg(S(SpreadKind.Eager), value.asSimpleRef(N)) :: Nil + val tail = Arg(S(SpreadKind.Eager), value.asSimpleRef) :: Nil syms.foldLeft(tail): - case (acc, sym) => Arg(N, sym.asSimpleRef(N)) :: acc - case None => syms.map(s => Arg(N, s.asSimpleRef(N))) + case (acc, sym) => Arg(N, sym.asSimpleRef) :: acc + case None => syms.map(s => Arg(N, s.asSimpleRef)) val call = Call(fun.sym.asMemberRef(fun.dSym), args ne_:: Nil)(true, true, false) val bod = Return(call) @@ -1065,7 +1065,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = captureSym.asSimpleRef(N) + override lazy val capturePath: Path = captureSym.asSimpleRef private val passedSymsMap_ : Map[Local, (vs: VarSymbol, ts: TermSymbol)] = passedSymsOrdered.map: s => s -> @@ -1131,8 +1131,8 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): // Uses the symbols from pl1. def applyPlToPl(pl1: ParamList, pl2: ParamList): List[Arg] = (pl1.restParam, pl2.restParam) match - case (S(rp), S(_)) => pl1.params.foldRight(Arg(S(SpreadKind.Eager), rp.sym.asSimpleRef(N)) :: Nil)((p, ls) => p.sym.asSimpleRef(N).asArg :: ls) - case (N, N) => pl1.paramSyms.map(s => s.asSimpleRef(N).asArg) + case (S(rp), S(_)) => pl1.params.foldRight(Arg(S(SpreadKind.Eager), rp.sym.asSimpleRef) :: Nil)((p, ls) => p.sym.asSimpleRef.asArg :: ls) + case (N, N) => pl1.paramSyms.map(s => s.asSimpleRef.asArg) case _ => die // If class has a main param list, the aux list comes after it diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 496e44b5db..78c5b01b80 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -57,7 +57,7 @@ class LoweringCtx( Subst(map + kv) */ def apply(v: Value): Value = v match - case Value.SimpleRef(l, _) => map.getOrElse(l, v) + case Value.SimpleRef(l) => map.getOrElse(l, v) case _ => v object LoweringCtx: def loweringCtx(using sub: LoweringCtx): LoweringCtx = sub @@ -116,7 +116,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): private def wasmIntrinsicPath(sym: BuiltinSymbol, unary: Bool): Opt[Path] = if config.target is CompilationTarget.Wasm then val map = if unary then wasmUnaryIntrinsicMap else wasmBinaryIntrinsicMap - map.get(sym.nme).map(name => State.wasmSymbol.asSimpleRef(N).selN(Tree.Ident(name))) + map.get(sym.nme).map(name => State.wasmSymbol.asSimpleRef.selN(Tree.Ident(name))) else N private lazy val wasmIntrinsicSymbols: Set[BlockMemberSymbol] = Set( ctx.builtins.wasm.plus_impl, @@ -136,10 +136,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ) lazy val unreachableFn = - Select(State.runtimeSymbol.asSimpleRef(N), Tree.Ident("unreachable"))(S(State.unreachableSymbol)) + Select(State.runtimeSymbol.asSimpleRef, Tree.Ident("unreachable"))(S(State.unreachableSymbol)) def unit: Path = - Select(State.runtimeSymbol.asSimpleRef(N), Tree.Ident("Unit"))(S(State.unitSymbol)) + Select(State.runtimeSymbol.asSimpleRef, Tree.Ident("Unit"))(S(State.unitSymbol)) // type Rcd = (mut: Bool, args: List[RcdArg]) // * Better, but Scala's patmat exhaustiveness chokes on it @@ -155,7 +155,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): source = Diagnostic.Source.Compilation ) lowerSuperCtorCall( - State.builtinOpsMap("super").asSimpleRef(N), + State.builtinOpsMap("super").asSimpleRef, isMlsFun = true, isTailCall = false, args.headOption, @@ -415,7 +415,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case acc: NELs[Ls[Arg]] => val tmp = loweringCtx.registerTempSymbol(N, "baseCall") val call = Call(fr, acc)(isMlsFun, true, isTailCall).withLoc(loc) - Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef(N), args, remainingArgss, isTailCall, loc)(k)) + Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall, loc)(k)) case (_ :: _, Nil) => k(Call(fr, acc.reverse.ne_!)(isMlsFun, true, isTailCall).withLoc(loc)) fr.targetSymbol match @@ -435,7 +435,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case args :: remainingArgss => val tmp = loweringCtx.registerTempSymbol(N, "callPrefix") Assign(tmp, call, - lowerRemainingCalls(tmp.asSimpleRef(N), args, remainingArgss, isTailCall, loc)(k)) + lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall, loc)(k)) /** Lower an instantiation with multiple argument lists into `Instantiate` and `Call` nodes, * trying to group as many as possible into a single `Instantiate` @@ -459,7 +459,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case (Nil, args :: remainingArgss) => val tmp = loweringCtx.registerTempSymbol(N, "baseInst") Assign(tmp, buildInstantiate(acc.reverse), - lowerRemainingCalls(tmp.asSimpleRef(N), args, remainingArgss, isTailCall = false, N)(k)) + lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall = false, N)(k)) case (remainingParamss, Nil) => // * Eta-expand missing argument lists by creating lambdas for each remaining param list. // * This makes partial `new C(args...)` explicit instead of relying on the JS class curry. @@ -471,7 +471,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): softTODO(ps.restParam.isEmpty, "Eta expanding rest parameters in constructor definitions is not yet supported") val freshParams = (ps.params zip freshSyms).map((p, s) => Param(p.flags, s, N, p.modulefulness)) val freshParamList = ParamList(ps.flags, freshParams, N) - val freshArgs = freshSyms.map(s => Arg(N, s.asSimpleRef(N))) + val freshArgs = freshSyms.map(s => Arg(N, s.asSimpleRef)) Lambda(freshParamList, Return(etaExpand(rest, accArgss :+ freshArgs)))(Nil) k(etaExpand(remainingParamss, acc.reverse)) // * Resolve the class definition to get the constructor param lists. @@ -498,7 +498,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case remainingArgss => val tmp = loweringCtx.registerTempSymbol(N, "baseInst") Assign(tmp, buildInstantiate(as :: Nil), - lowerRemainingCalls(tmp.asSimpleRef(N), remainingArgss.head, remainingArgss.tail, isTailCall = false, N)(k)) + lowerRemainingCalls(tmp.asSimpleRef, remainingArgss.head, remainingArgss.tail, isTailCall = false, N)(k)) else zipArgs(ctorParamLists, args, Nil) def lowerArgs(arg: Term)(k: Ls[Arg] => Block)(using LoweringCtx): Block = @@ -616,7 +616,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): warnStmt (sym, disamb) match case (sym: (LocalVarSymbol | BuiltinSymbol), _) => - k(loweringCtx(sym.asSimpleRef(N).withLocOf(ref))) + k(loweringCtx(sym.asSimpleRef.withLocOf(ref))) case (sym: BlockMemberSymbol, _) => k(loweringCtx(sym.asMemberRef(disamb.orElse(sym.asPrincipal).get).withLocOf(ref))) case (sym: InnerSymbol, _) => @@ -668,13 +668,13 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign( isContinue, Call( - State.builtinOpsMap("===").asSimpleRef(N), - (bodyResult.asSimpleRef(N).asArg :: State.runtimeSymbol.asSimpleRef(N).selSN("Continue").asArg :: Nil) ne_:: Nil, + State.builtinOpsMap("===").asSimpleRef, + (bodyResult.asSimpleRef.asArg :: State.runtimeSymbol.asSimpleRef.selSN("Continue").asArg :: Nil) ne_:: Nil, )(true, false, false), Match( - isContinue.asSimpleRef(ErasedType.Primitive(PrimitiveType.Bool)), + isContinue.asSimpleRef, (Case.Lit(Tree.BoolLit(true)) -> Continue(label)) :: Nil, - S(Assign(result, bodyResult.asSimpleRef(N), Break(label))), + S(Assign(result, bodyResult.asSimpleRef, Break(label))), End("label continue-sentinel dispatch") ) ) @@ -683,7 +683,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): label, loop = hasLocalContinue(body), bodyBlock, - k(result.asSimpleRef(N)) + k(result.asSimpleRef) ) case st.Break(label, result, value) => value match @@ -711,7 +711,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ErrorReport( msg"Expected arguments for builtin operator '${sym.nme}'" -> t.toLoc :: Nil, S(arg), source = Diagnostic.Source.Compilation) - k(sym.asSimpleRef(N).withLocOf(ref)) + k(sym.asSimpleRef.withLocOf(ref)) case st.Tup(Fld(FldFlags.benign(), arg, N) :: Nil) => if !sym.unary then raise: ErrorReport( @@ -719,7 +719,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): source = Diagnostic.Source.Compilation) subTerm(arg): ar => val target = wasmIntrinsicPath(sym, unary = true) - .getOrElse(sym.asSimpleRef(N).withLocOf(ref)) + .getOrElse(sym.asSimpleRef.withLocOf(ref)) k(Call(target, (Arg(N, ar) :: Nil) ne_:: Nil)(true, false, false)) case st.Tup(Fld(FldFlags.benign(), arg1, N) :: Fld(FldFlags.benign(), arg2, N) :: Nil) => if !sym.binary then raise: @@ -741,7 +741,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ar1, (Case.Lit(posLit) -> term_nonTail(arg2)(Assign(ts, _, End()))) :: Nil, S(Assign(ts, Value.Lit(negLit), End())), - k(ts.asSimpleRef(N)), + k(ts.asSimpleRef), ) sym match case State.andSymbol => mkBooleanMatch(true) @@ -749,7 +749,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case _ => subTerm_nonTail(arg2): ar2 => val target = wasmIntrinsicPath(sym, unary = false) - .getOrElse(sym.asSimpleRef(N).withLocOf(ref)) + .getOrElse(sym.asSimpleRef.withLocOf(ref)) k(Call(target, (Arg(N, ar1) :: Arg(N, ar2) :: Nil) ne_:: Nil)(true, false, false)) case _ => fail: ErrorReport( @@ -789,21 +789,21 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): instantiated match case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.bitand) => - conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("bitand"))) + conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("bitand"))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.bitnot) => - conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("bitnot"))) + conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("bitnot"))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.bitor) => - conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("bitor"))) + conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("bitor"))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.shl) => - conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("shl"))) + conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("shl"))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.js.try_catch) => - conclude(State.runtimeSymbol.asSimpleRef(N).selN(Tree.Ident("try_catch"))) + conclude(State.runtimeSymbol.asSimpleRef.selN(Tree.Ident("try_catch"))) case t if t.resolvedSym.exists { case sym: BlockMemberSymbol => wasmIntrinsicSymbols.contains(sym) case _ => false } => val sym = t.resolvedSym.get.asInstanceOf[BlockMemberSymbol] - conclude(State.wasmSymbol.asSimpleRef(N).selN(Tree.Ident(sym.nme))) + conclude(State.wasmSymbol.asSimpleRef.selN(Tree.Ident(sym.nme))) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.debug.printStack) => if !config.effectHandlers.exists(_.debug) then return fail: @@ -811,7 +811,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): msg"Debugging functions are not enabled" -> t.toLoc :: Nil, source = Diagnostic.Source.Compilation) - conclude(State.runtimeSymbol.asSimpleRef(N).selSN("raisePrintStackEffect").withLocOf(baseF)) + conclude(State.runtimeSymbol.asSimpleRef.selSN("raisePrintStackEffect").withLocOf(baseF)) case t if instantiatedResolvedBms.exists(_ is ctx.builtins.scope.locally) => // scope.locally only applies to the innermost call; extra args are applied on top if allArgs.length > 1 then @@ -867,7 +867,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): subTerms(as): asr => HandleBlock(lhs, resSym, par, asr, cls, handlers, inScopedBlock(returnedTerm(bod)), - k(resSym.asSimpleRef(N))) + k(resSym.asSimpleRef)) case st.Blk(sts, res) => block(sts, R(res), inStmtPos = inStmtPos)(k) case Assgn(lhs, rhs) => lhs match @@ -961,7 +961,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): TryBlock( subTerm_nonTail(sub)(p => Assign(l, p, End())), subTerm_nonTail(finallyDo)(_ => End()), - k(l.asSimpleRef(N)) + k(l.asSimpleRef) ) case Quoted(body) => quote(body)(k) @@ -1012,13 +1012,13 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): // subTerm(t)(k) def setupTerm(name: Str, args: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = - k(Instantiate(mut = false, State.termSymbol.asSimpleRef(N).selSN(name), args.map(_.asArg) :: Nil)) + k(Instantiate(mut = false, State.termSymbol.asSimpleRef.selSN(name), args.map(_.asArg) :: Nil)) def setupQuotedKeyword(kw: Str): Path = - State.termSymbol.asSimpleRef(N).selSN("Keyword").selSN(kw) + State.termSymbol.asSimpleRef.selSN("Keyword").selSN(kw) def setupSymbol(symbol: Local)(k: Result => Block)(using LoweringCtx): Block = - k(Instantiate(mut = false, State.termSymbol.asSimpleRef(N).selSN("Symbol"), + k(Instantiate(mut = false, State.termSymbol.asSimpleRef.selSN("Symbol"), (Value.Lit(Tree.StrLit(symbol.nme)).asArg :: Nil) :: Nil)) def quotePattern(p: FlatPattern)(k: Result => Block)(using LoweringCtx): Block = p match @@ -1037,20 +1037,20 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): blockBuilder.assign(l1, r1) .chain(b => quotePattern(pattern)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(continuation)(r3 => Assign(l3, r3, b))) - .chain(b => setupTerm("Branch", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef(N)))(r4 => Assign(l4, r4, b))) + .chain(b => setupTerm("Branch", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(r4 => Assign(l4, r4, b))) .chain(b => quoteSplit(tail)(r5 => Assign(l5, r5, b))) - .rest(setupTerm("Cons", (l4 :: l5 :: Nil).map(s => s.asSimpleRef(N)))(k)) + .rest(setupTerm("Cons", (l4 :: l5 :: Nil).map(s => s.asSimpleRef))(k)) case Split.Let(sym, term, tail) => setupSymbol(sym): r1 => loweringCtx.collectScopedSym(sym) val l1, l2, l3 = loweringCtx.registerTempSymbol(N) blockBuilder.assign(l1, r1) - .chain(b => setupTerm("Ref", l1.asSimpleRef(N) :: Nil)(r => Assign(sym, r, b))) + .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(term)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(tail)(r3 => Assign(l3, r3, b))) - .rest(setupTerm("Let", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef(N)))(k)) + .rest(setupTerm("Let", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(k)) case Split.Else(default) => quote(default): r => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, setupTerm("Else", l.asSimpleRef(N) :: Nil)(k)) + Assign(l, r, setupTerm("Else", l.asSimpleRef :: Nil)(k)) case Split.End => setupTerm("End", Nil)(k) case Split.LetSplit(sym, tail) => setupSymbol(sym): r1 => loweringCtx.collectScopedSym(sym) @@ -1064,7 +1064,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): lazy val setupFilename: Path = val state = summon[State] - state.importSymbol.asSimpleRef(N).selSN("meta").selSN("url") + state.importSymbol.asSimpleRef.selSN("meta").selSN("url") def quote(t: st)(k: Result => Block)(using LoweringCtx): Block = t match case Lit(lit) => @@ -1075,14 +1075,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Resolved(Ref(sym), disamb) => sym match case sym: BlockMemberSymbol => k(sym.asMemberRef(disamb)) - case sym: (LocalVarSymbol | BuiltinSymbol) => k(sym.asSimpleRef(N)) + case sym: (LocalVarSymbol | BuiltinSymbol) => k(sym.asSimpleRef) case sym => lastWords(s"Unexpected symbol kind ${sym.getClass.getSimpleName}: $sym") case Ref(sym) => k(sym.asPath) case SynthSel(Ref(sym: ModuleOrObjectSymbol), name) => // Local cross-stage references setupSymbol(sym): r1 => val l1, l2 = loweringCtx.registerTempSymbol(N) - Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef(N) :: setupFilename :: Value.Lit(syntax.Tree.UnitLit(false)) :: Nil)(r2 => - Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef(N) :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) + Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.UnitLit(false)) :: Nil)(r2 => + Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) )) case SynthSel(Ref(sym: BlockMemberSymbol), name) => // Multi-file cross-stage references if config.qqEnabled then fail: @@ -1097,8 +1097,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val basePath = base.up val targetPath = filename val relPath = targetPath.relativeTo(basePath).map(_.toString).getOrElse(targetPath.toString) - Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef(N) :: setupFilename :: Value.Lit(syntax.Tree.StrLit(relPath)) :: Nil)(r2 => - Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef(N) :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) + Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.StrLit(relPath)) :: Nil)(r2 => + Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) )) case _ => fail: ErrorReport( @@ -1114,13 +1114,13 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign( arr, Tuple(mut = false, ds.reverse.map(_.asArg)), - Assign(l, r, setupTerm("Lam", arr.asSimpleRef(N) :: l.asSimpleRef(N) :: Nil)(k))) + Assign(l, r, setupTerm("Lam", arr.asSimpleRef :: l.asSimpleRef :: Nil)(k))) case sym :: rest => loweringCtx.collectScopedSym(sym) setupSymbol(sym): r => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, setupTerm("Ref", l.asSimpleRef(N) :: Nil): r1 => - Assign(sym, r1, rec(rest, l.asSimpleRef(N) :: ds)(k))) + Assign(l, r, setupTerm("Ref", l.asSimpleRef :: Nil): r1 => + Assign(sym, r1, rec(rest, l.asSimpleRef :: ds)(k))) rec(params.params.map(_.sym), Nil)(k) // TODO: restParam? case App(lhs, Tup(rhs)) => quote(lhs): r1 => def rec(es: Ls[Elem], xs: Ls[Path])(k: Result => Block): Block = es match @@ -1129,14 +1129,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign( arrSym, Tuple(mut = false, xs.reverse.map(_.asArg)), - setupTerm("Tup", arrSym.asSimpleRef(N) :: Nil): r2 => + setupTerm("Tup", arrSym.asSimpleRef :: Nil): r2 => val l1 = loweringCtx.registerTempSymbol(N) val l2 = loweringCtx.registerTempSymbol(N) - Assign(l1, r1, Assign(l2, r2, setupTerm("App", l1.asSimpleRef(N) :: l2.asSimpleRef(N) :: Nil)(k))) + Assign(l1, r1, Assign(l2, r2, setupTerm("App", l1.asSimpleRef :: l2.asSimpleRef :: Nil)(k))) ) case Fld(_, t, _) :: rest => quote(t): r2 => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r2, rec(rest, l.asSimpleRef(N) :: xs)(k)) + Assign(l, r2, rec(rest, l.asSimpleRef :: xs)(k)) case Spd(eager, term) :: rest => fail: ErrorReport( @@ -1152,17 +1152,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N) val arrSym = loweringCtx.registerTempSymbol(N, "arr") blockBuilder.assign(l1, r1) - .chain(b => setupTerm("Ref", l1.asSimpleRef(N) :: Nil)(r => Assign(sym, r, b))) + .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(rhs)(r2 => Assign(l2, r2, b))) .chain(b => quote(res)(r3 => Assign(l3, r3, b))) - .chain(b => setupTerm("LetDecl", l1.asSimpleRef(N) :: Nil)(r4 => Assign(l4, r4, b))) - .chain(b => setupTerm("DefineVar", l1.asSimpleRef(N) :: l2.asSimpleRef(N) :: Nil)(r5 => Assign(l5, r5, b))) - .assign(arrSym, Tuple(mut = false, (l4 :: l5 :: Nil).map(s => s.asSimpleRef(N).asArg))) - .rest(setupTerm("Blk", arrSym.asSimpleRef(N) :: l3.asSimpleRef(N) :: Nil)(k)) + .chain(b => setupTerm("LetDecl", l1.asSimpleRef :: Nil)(r4 => Assign(l4, r4, b))) + .chain(b => setupTerm("DefineVar", l1.asSimpleRef :: l2.asSimpleRef :: Nil)(r5 => Assign(l5, r5, b))) + .assign(arrSym, Tuple(mut = false, (l4 :: l5 :: Nil).map(s => s.asSimpleRef.asArg))) + .rest(setupTerm("Blk", arrSym.asSimpleRef :: l3.asSimpleRef :: Nil)(k)) } case IfLike(_, IfLikeForm.ReturningIf, split) => quoteSplit(split.getExpandedSplit): r => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, setupTerm("IfLike", setupQuotedKeyword("If") :: l.asSimpleRef(N) :: Nil)(k)) + Assign(l, r, setupTerm("IfLike", setupQuotedKeyword("If") :: l.asSimpleRef :: Nil)(k)) case Unquoted(body) => term(body)(k) case _ => fail: ErrorReport( @@ -1233,7 +1233,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign( rcdSym, Record(mut = false, fsr.reverse), - k((Arg(N, rcdSym.asSimpleRef(N)) :: asr).reverse))) + k((Arg(N, rcdSym.asSimpleRef) :: asr).reverse))) inline def plainArgs(ts: Ls[st])(k: Ls[Arg] => Block)(using LoweringCtx): Block = @@ -1262,7 +1262,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Define(lamDef, k(lamDef.asPath)) case r => val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, k(l.asSimpleRef(N))) + Assign(l, r, k(l.asSimpleRef)) def program(main: st.Blk): Program = @@ -1416,12 +1416,12 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State) blockBuilder .assign(selRes, Select(p, nme)(disamb)) .assign(State.noSymbol, Select(p, Tree.Ident(nme.name+"$__checkNotMethod"))(N)) - .ifthen(selRes.asSimpleRef(N), + .ifthen(selRes.asSimpleRef, Case.Lit(syntax.Tree.UnitLit(false)), Throw(Instantiate(mut = false, Select(State.globalThisSymbol.asThis, Tree.Ident("Error"))(N), (Value.Lit(syntax.Tree.StrLit(s"Access to required field '${nme.name}' yielded 'undefined'")).asArg :: Nil) :: Nil)) ) - .rest(k(selRes.asSimpleRef(N))) + .rest(k(selRes.asSimpleRef)) @@ -1442,9 +1442,9 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) extension (k: Block => Block) def |>: (b: Block): Block = k(b) - private val traceLogFn = State.runtimeSymbol.asSimpleRef(N).selSN("TraceLogger").selSN("log") - private val traceLogIndentFn = State.runtimeSymbol.asSimpleRef(N).selSN("TraceLogger").selSN("indent") - private val traceLogResetFn = State.runtimeSymbol.asSimpleRef(N).selSN("TraceLogger").selSN("resetIndent") + private val traceLogFn = State.runtimeSymbol.asSimpleRef.selSN("TraceLogger").selSN("log") + private val traceLogIndentFn = State.runtimeSymbol.asSimpleRef.selSN("TraceLogger").selSN("indent") + private val traceLogResetFn = State.runtimeSymbol.asSimpleRef.selSN("TraceLogger").selSN("resetIndent") private val strConcatFn = selFromGlobalThis("String", "prototype", "concat", "call") private val inspectFn = selFromGlobalThis("util", "inspect") @@ -1477,34 +1477,34 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) val psSymArgs = psInspectedSyms.zipWithIndex.foldRight[Ls[Arg]](Arg(N, Value.Lit(Tree.StrLit(")"))) :: Nil): case (((s, p), i), acc) => if i == psInspectedSyms.length - 1 - then Arg(N, s.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: acc - else Arg(N, s.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: Arg(N, Value.Lit(Tree.StrLit(", "))) :: acc + then Arg(N, s.asSimpleRef) :: acc + else Arg(N, s.asSimpleRef) :: Arg(N, Value.Lit(Tree.StrLit(", "))) :: acc val tmp1, tmp2, tmp3 = loweringCtx.registerTempSymbol(N) assignStmts(psInspectedSyms.map: (pInspectedSym, pSym) => - pInspectedSym -> pureCall(inspectFn, Arg(N, pSym.asSimpleRef(N)) :: Nil) + pInspectedSym -> pureCall(inspectFn, Arg(N, pSym.asSimpleRef) :: Nil) *) |>: assignStmts( enterMsgSym -> pureCall( strConcatFn, Arg(N, Value.Lit(Tree.StrLit(s"CALL ${name.getOrElse("[arrow function]")}("))) :: psSymArgs ), - tmp1 -> pureCall(traceLogFn, Arg(N, enterMsgSym.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: Nil), + tmp1 -> pureCall(traceLogFn, Arg(N, enterMsgSym.asSimpleRef) :: Nil), prevIndentLvlSym -> pureCall(traceLogIndentFn, Nil) ) |>: term(bod)(r => assignStmts( resSym -> r, - resInspectedSym -> pureCall(inspectFn, Arg(N, resSym.asSimpleRef(N)) :: Nil), + resInspectedSym -> pureCall(inspectFn, Arg(N, resSym.asSimpleRef) :: Nil), retMsgSym -> pureCall( strConcatFn, - Arg(N, Value.Lit(Tree.StrLit("=> "))) :: Arg(N, resInspectedSym.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: Nil + Arg(N, Value.Lit(Tree.StrLit("=> "))) :: Arg(N, resInspectedSym.asSimpleRef) :: Nil ), - tmp2 -> pureCall(traceLogResetFn, Arg(N, prevIndentLvlSym.asSimpleRef(N)) :: Nil), - tmp3 -> pureCall(traceLogFn, Arg(N, retMsgSym.asSimpleRef(ErasedType.Primitive(PrimitiveType.Str))) :: Nil) + tmp2 -> pureCall(traceLogResetFn, Arg(N, prevIndentLvlSym.asSimpleRef) :: Nil), + tmp3 -> pureCall(traceLogFn, Arg(N, retMsgSym.asSimpleRef) :: Nil) ) |>: - Ret(resSym.asSimpleRef(N)) + Ret(resSym.asSimpleRef) ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index a815632645..7b80670ce8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -164,7 +164,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): else doc def print(value: Value)(using Scope): Document = value match - case Value.SimpleRef(l, _) => print(l) + case Value.SimpleRef(l) => print(l) case Value.MemberRef(bms, disamb) => showSymbol(bms.nme, S(disamb)) case Value.This(sym) if sym === State.globalThisSymbol => showSymbol(sym.nme, S(sym.asDefnSym)) case Value.This(sym) => doc"${print(sym)}.this" diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala index 0f85824f2f..62c13f56ea 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala @@ -55,7 +55,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n def assign(res: Result, symName: Str = "tmp")(k: Path => Block): Block = // TODO: skip assignment if res: Path? val sym = new TempSymbol(N, symName) - Scoped(Set(sym), Assign(sym, res, k(sym.asSimpleRef(N)))) + Scoped(Set(sym), Assign(sym, res, k(sym.asSimpleRef))) def tuple(elems: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = assign(Tuple(false, elems.map(asArg)), symName)(k) @@ -66,8 +66,8 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n // helpers for instrumenting Block - def blockMod(name: Str) = summon[State].blockSymbol.asSimpleRef(N).selSN(name) - def optionMod(name: Str) = summon[State].optionSymbol.asSimpleRef(N).selSN(name) + def blockMod(name: Str) = summon[State].blockSymbol.asSimpleRef.selSN(name) + def optionMod(name: Str) = summon[State].optionSymbol.asSimpleRef.selSN(name) def blockCtor(name: Str, args: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = call(blockMod(name), args, true, symName)(k) @@ -169,7 +169,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n // rulePath ctx.get(p).map(k).getOrElse: p match - case Value.SimpleRef(l, _) => + case Value.SimpleRef(l) => transformSymbol(l): sym => blockCtor("ValueSimpleRef", Ls(sym), "var")(k) case Value.MemberRef(bms, disamb) => @@ -213,7 +213,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n raise(ErrorReport(msg"Instantiate with multiple argument lists not supported in staged module." -> r.toLoc :: Nil)) End() // desugar Runtime.Tuple.get into Select - case Call(fun, Ls(Arg(_, scrut), Arg(_, Value.Lit(Tree.IntLit(idx)))) :: _) if fun == Value.SimpleRef(State.runtimeSymbol, N).selSN("Tuple").selSN("get") => + case Call(fun, Ls(Arg(_, scrut), Arg(_, Value.Lit(Tree.IntLit(idx)))) :: _) if fun == Value.SimpleRef(State.runtimeSymbol).selSN("Tuple").selSN("get") => transformPath(Select(scrut, Tree.Ident(idx.toString()))(N))(k) case Call(fun, argss) => val stagedFunPath = fun match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala index f424306c89..55b88d3c89 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala @@ -279,7 +279,7 @@ private def matchChainToSwitch(m: MatchChain): SwitchLike = object SpecializedSwitch: def unapply(b: Block) = b match - case m @ Match(scrut = r @ Value.SimpleRef(l, _)) => + case m @ Match(scrut = r @ Value.SimpleRef(l)) => val chain = findMatchChainRec(m, r, Nil) val SwitchLike(scrut, cases, dflt, rest) = matchChainToSwitch(chain) if cases.size < 2 then N diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala index 073e55ac09..bb6ab15766 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -12,7 +12,7 @@ import hkmc2.codegen.HandlerLowering.FnOrCls class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: StackSafetyMap)(using State, Config): private val STACK_DEPTH_IDENT: Tree.Ident = Tree.Ident("stackDepth") - private val runtimePath: Path = State.runtimeSymbol.asSimpleRef(N) + private val runtimePath: Path = State.runtimeSymbol.asSimpleRef private val checkDepthPath: Path = runtimePath.selN(Tree.Ident("checkDepth")) private val runStackSafePath: Path = runtimePath.selN(Tree.Ident("runStackSafe")) private val stackDepthPath: Path = runtimePath.selN(STACK_DEPTH_IDENT) @@ -20,7 +20,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S private def intLit(n: BigInt) = Value.Lit(Tree.IntLit(n)) private def op(op: String, a: Path, b: Path) = - Call(State.builtinOpsMap(op).asSimpleRef(N), (a.asArg :: b.asArg :: Nil) ne_:: Nil)(true, false, false) + Call(State.builtinOpsMap(op).asSimpleRef, (a.asArg :: b.asArg :: Nil) ne_:: Nil)(true, false, false) // Increases the stack depth, assigns the call to a value, then decreases the stack depth // then binds that value to a desired block @@ -30,7 +30,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S else blockBuilder .assign(sym, res) - .assignFieldN(runtimePath, STACK_DEPTH_IDENT, curDepth.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int))) + .assignFieldN(runtimePath, STACK_DEPTH_IDENT, curDepth.asSimpleRef) .rest(f(sym.asPath)) def wrapStackSafe(body: Block, resSym: Local, rest: Block) = @@ -48,7 +48,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S // Rewrites anything that can contain a Call to increase the stack depth def transform(b: Block, curDepth: => LocalVarSymbol, isTopLevel: Bool = false): Block = def usesStack(r: Result) = r match - case Call(Value.SimpleRef(_: BuiltinSymbol, _), _) => false + case Call(Value.SimpleRef(_: BuiltinSymbol), _) => false case c: Call if !c.mayRaiseEffects => false // a call can only trigger a stack delay if it can raise effects case _: Call | _: Instantiate => true case _ => false @@ -98,7 +98,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S new BlockTraverserShallow: applyBlock(b) override def applyResult(r: Result): Unit = r match - case Call(Value.SimpleRef(_: BuiltinSymbol, _), _) => () + case Call(Value.SimpleRef(_: BuiltinSymbol), _) => () case _: Call | _: Instantiate => trivial = false case _ => () trivial diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala index fadc239889..da41b40027 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala @@ -206,10 +206,10 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends case _ => super.applyPath(p)(k) override def applyValue(v: Value)(k: Value => Block): Block = v match - case Value.SimpleRef(l, _) => + case Value.SimpleRef(l) => mapping.get(l) match case Some(newSym: (LocalVarSymbol | BuiltinSymbol)) => - k(newSym.asSimpleRef(N)) + k(newSym.asSimpleRef) case _ => super.applyValue(v)(k) case Value.MemberRef(bms, disamb) => mapping.get(bms) match @@ -219,7 +219,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends case Some(nd) => lastWords(s"unexpected symbol kind for disamb: ${nd}") case N => lastWords(s"unexpected lack of refreshed disamb symbol for $disamb") k(newBms.asMemberRef(newDisamb)) - case Some(newSym: (LocalVarSymbol | TempSymbol)) => k(newSym.asSimpleRef(N)) + case Some(newSym: (LocalVarSymbol | TempSymbol)) => k(newSym.asSimpleRef) case _ => super.applyValue(v)(k) case Value.This(sym) => mapping.get(sym) match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 213ed666b6..3bcf8a89ba 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -370,13 +370,13 @@ class TailRecOpt(using State, TL, Raise): // `assignedSyms` contains all of the param symbols that have been assigned to // before the current assignment, and thus references to them must be rewritten // to point to a temporary variable. - case Value.SimpleRef(l: VarSymbol, _) => assignedSyms.get(l) match + case Value.SimpleRef(l: VarSymbol) => assignedSyms.get(l) match case S(v) => val tmpSym = v.force_! // Adding this to `requiredTmps` will make sure we set the temporary variable // to the current variable at the start of the rewritten call. requiredTmps += (l, tmpSym) - k(tmpSym.asSimpleRef(N)) + k(tmpSym.asSimpleRef) case _ => super.applyValue(v)(k) case _ => super.applyValue(v)(k) @@ -387,7 +387,7 @@ class TailRecOpt(using State, TL, Raise): val selfAssigns = argListResults.flatMap: (_, thisParamSyms, args, argsRes) => argsRes match case CallArgsResult.Success(res) => thisParamSyms.zip(res).collect: - case (sym1, Value.SimpleRef(sym2, _)) if sym1 === sym2 => sym1 + case (sym1, Value.SimpleRef(sym2)) if sym1 === sym2 => sym1 case CallArgsResult.ForceSpread => List.empty assignedSyms --= selfAssigns @@ -413,7 +413,7 @@ class TailRecOpt(using State, TL, Raise): // in `rewrite`. Also note that `paramRewriter` will add all encountered rewritten variables // to `requiredTmps`. val ret = paramRewriter.applyResult(res)(Assign(sym, _, acc)) match - case Assign(sym, Value.SimpleRef(sym1, _), rest) if sym === sym1 => rest // avoid useless assignments + case Assign(sym, Value.SimpleRef(sym1), rest) if sym === sym1 => rest // avoid useless assignments case x => x ret case CallArgsResult.ForceSpread => @@ -433,7 +433,7 @@ class TailRecOpt(using State, TL, Raise): // Main args def mainArgs(rest: List[Path]) = (0 until paramList.size).toList.foldRight(rest): - case (n, acc) => DynSelect(tupleSym.asSimpleRef(N), Value.Lit(Tree.IntLit(n)), true) :: acc + case (n, acc) => DynSelect(tupleSym.asSimpleRef, Value.Lit(Tree.IntLit(n)), true) :: acc // If the rest param exists, append a slice val (initialBlk: (Block => Block), pathList: List[Path]) = @@ -441,10 +441,10 @@ class TailRecOpt(using State, TL, Raise): val sliceResSym = TempSymbol(N, "sliceRes") // runtime.Tuple.slice(tupleSym, paramList.length, 0) val sliceRes = Call( - State.runtimeSymbol.asSimpleRef(N) + State.runtimeSymbol.asSimpleRef .sel(Tree.Ident("Tuple"), State.tupleSymbol) .sel(Tree.Ident("slice"), State.tupleSliceSymbol), - (tupleSym.asSimpleRef(N).asArg + (tupleSym.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(paramList.length)).asArg :: Value.Lit(Tree.IntLit(0)).asArg :: Nil) ne_:: Nil @@ -452,7 +452,7 @@ class TailRecOpt(using State, TL, Raise): val blk = blockBuilder .assignScoped(tupleSym, tupleRes) .assignScoped(sliceResSym, sliceRes) - (blk, mainArgs(sliceResSym.asSimpleRef(N) :: Nil)) + (blk, mainArgs(sliceResSym.asSimpleRef :: Nil)) else (blockBuilder.assignScoped(tupleSym, tupleRes), mainArgs(Nil)) end val @@ -466,7 +466,7 @@ class TailRecOpt(using State, TL, Raise): Scoped( requiredTmps.values.toSet, requiredTmps.toList.foldRight(assignments): - case ((v, l), acc) => Assign(l, v.asSimpleRef(N), acc)) + case ((v, l), acc) => Assign(l, v.asSimpleRef, acc)) // Not a tail call case _ => super.applyBlock(b) @@ -474,7 +474,7 @@ class TailRecOpt(using State, TL, Raise): // Rewrite the result with symbols pointing to the merged function parameters and possibly the copied parameters (see `copiedParams`). val blk = applyBlock(symRewriter.applyBlock(b)) val withCopied = copiedParamSyms.toArray.sortBy(_._1.uid).foldRight(blk): - case ((ogParam, copiedParam), accBlk) => Assign(copiedParam, paramSymsArr(paramsIdxes(ogParam)).asSimpleRef(N), accBlk) + case ((ogParam, copiedParam), accBlk) => Assign(copiedParam, paramSymsArr(paramsIdxes(ogParam)).asSimpleRef, accBlk) Scoped(copiedParamSyms.map(_._2).toSet, withCopied) val arms = funs.map: f => @@ -482,7 +482,7 @@ class TailRecOpt(using State, TL, Raise): val switch = if arms.length === 1 then arms.head._2 - else Match(curIdSym.asSimpleRef(ErasedType.Primitive(PrimitiveType.Int)), arms, N, End()) + else Match(curIdSym.asSimpleRef, arms, N, End()) val loop = Label(loopSym, true, switch, End()) @@ -493,7 +493,7 @@ class TailRecOpt(using State, TL, Raise): val rewrittenFuns = if funs.size === 1 then Nil else funs.map: f => - val paramArgs = getParamSyms(f).map(s => s.asSimpleRef(N).asArg) + val paramArgs = getParamSyms(f).map(s => s.asSimpleRef.asArg) val args = Value.Lit(Tree.IntLit(dSymIds(f.dSym))).asArg :: paramArgs @@ -525,7 +525,7 @@ class TailRecOpt(using State, TL, Raise): owner, loopBms, loopDSym, PlainParamList(params) :: Nil, loop)(N, annotations = Annot.Private :: Nil) - val paramArgs = getParamSyms(f).map(s => s.asSimpleRef(N).asArg) + val paramArgs = getParamSyms(f).map(s => s.asSimpleRef.asArg) val internalSel = owner match case Some(value) => Select(value.asThis, Tree.Ident(loopBms.nme))(S(loopDSym)) case None => loopBms.asMemberRef(loopDSym) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala index 42cabdf482..bbe13057ee 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala @@ -73,7 +73,7 @@ class UsedVarAnalyzer(b: Block, scopeData: ScopeData)(using State): case _ => super.applyBlock(b) override def applyPath(p: Path): Unit = p match - case Value.SimpleRef(_: BuiltinSymbol, _) => super.applyPath(p) + case Value.SimpleRef(_: BuiltinSymbol) => super.applyPath(p) case RefOfBms(_, SDSym(dSym), _) => val node = scopeData.getNode(dSym) node.obj match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala index 6d43cab888..e184e54240 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala @@ -68,7 +68,7 @@ class WorkerWrapper workerBody, )(fun.configOverride, withoutInline(fun.annotations)) val workerArgs = fun.params.flatMap(_.params).map: param => - Arg(N, param.sym.asSimpleRef(N)) + Arg(N, param.sym.asSimpleRef) val wrapperBody = Return( Call(worker.asPath, workerArgs ne_:: Nil)(isMlsFun = true, mayRaiseEffects = true, explicitTailCall = false), ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala index d14e54bcb0..95a8d49cd6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala @@ -272,7 +272,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): override def applyValue(v: Value): Unit = v match - case Value.SimpleRef(l, _) if !inCtx(l) => freeVars.add(l) + case Value.SimpleRef(l) if !inCtx(l) => freeVars.add(l) case Value.MemberRef(bms, _) if !inCtx(bms) && bms.asClsLike.isEmpty => freeVars.add(bms) case _ => super.applyValue(v) @@ -395,7 +395,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): r match case s@TrackableSelect(from, _, _) => if branchSelSyms.isDefinedAt(s.uid.concreteId) then - k(branchSelSyms(s.uid.concreteId).asSimpleRef(N)) + k(branchSelSyms(s.uid.concreteId).asSimpleRef) else if solver.finalDtorSrcs.contains(s.uid.concreteId) then applyPath(from)(k) else @@ -414,7 +414,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val ctorInfo = solver.fusingCtorInfo(ctor.uid.concreteId) val idx = ctorInfo.args.unzip._1.indexOf(field) val fieldSyms = mkCtorFieldSyms(ctor.uid.concreteId) - args.zip(fieldSyms).foldRight(k(fieldSyms(idx).asSimpleRef(N))): + args.zip(fieldSyms).foldRight(k(fieldSyms(idx).asSimpleRef)): case (Arg(N, a) -> s, rest) => applyPath(a): fusedField => Scoped(Set.single(s), Assign(s, fusedField, rest)) @@ -443,11 +443,11 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): Assign( lambdaSym, callBranchFun, - k(lambdaSym.asSimpleRef(N))) + k(lambdaSym.asSimpleRef)) ) case s@TrackableSelect(from, _, _) => if branchSelSyms.isDefinedAt(s.uid.concreteId) then - k(branchSelSyms(s.uid.concreteId).asSimpleRef(N)) + k(branchSelSyms(s.uid.concreteId).asSimpleRef) else if solver.finalDtorSrcs.contains(s.uid.concreteId) then applyPath(from)(k) else diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala index 25fa96e0a6..f1316f6d72 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala @@ -33,7 +33,7 @@ object FlowAnalysis: def getResult = resultIdToResult(resultId) def getReferredSym: Symbol = resultId.getResult match - case Value.SimpleRef(s, _) => s + case Value.SimpleRef(s) => s case Value.MemberRef(bms, _) => bms case Value.This(sym) => sym case e => lastWords(s"assumption failed: $e is not a SimpleRef, MemberRef, or ThisRef") @@ -90,7 +90,7 @@ type OriginId = ResultId | FunId /** Extracts the underlying symbol of a variable-like reference, for flow-tracking use. */ object TrackedSymOf: def unapply(p: Value.RefLike | Select)(using Elaborator.State): Opt[Symbol] = p match - case Value.SimpleRef(sym, _) => S(sym) + case Value.SimpleRef(sym) => S(sym) case Value.MemberRef(_, disamb) => S(disamb) case Value.This(sym) => S(sym) case s: Select => s.symbol.flatMap: selSym => @@ -116,8 +116,8 @@ object PossibleTrackableTupleSelect: def unapply(s: Result)(using eState: Elaborator.State): Opt[Value.SimpleRef -> Int] = s match case Call( - Select(Select(Value.SimpleRef(runtimeSym, _), Tree.Ident("Tuple")), Tree.Ident("get")), - (Arg(N, ref@Value.SimpleRef(scrut, _)) :: Arg(N, Value.Lit(Tree.IntLit(n))) :: Nil) :: Nil + Select(Select(Value.SimpleRef(runtimeSym), Tree.Ident("Tuple")), Tree.Ident("get")), + (Arg(N, ref@Value.SimpleRef(scrut)) :: Arg(N, Value.Lit(Tree.IntLit(n))) :: Nil) :: Nil ) if runtimeSym is eState.runtimeSymbol => S(ref -> n.toInt) case _ => N @@ -125,7 +125,7 @@ object TrackableSelect: def unapply(s: Result)(using pre: FlowPreAnalyzer, eState: Elaborator.State): Opt[(from: Path, field: SelField, owner: CtorCls)] = given fState: FlowAnalysis.State = pre.fState s match - case sel@PossibleTrackableTupleSelect((ref@Value.SimpleRef(scrut, _)) -> ith) => + case sel@PossibleTrackableTupleSelect((ref@Value.SimpleRef(scrut)) -> ith) => pre.res.getEnclosingMatchesForSel(sel.uid).find(_._1.getReferredSym is scrut).flatMap: case (_, Some(tupSize: Int)) => S(ref, ith, tupSize) case _ => N @@ -146,7 +146,7 @@ object CtorRef: yield cls def unapply(p: Path)(using Elaborator.State): Opt[ClassSymbol | ModuleOrObjectSymbol] = p match - case Value.SimpleRef(sym, _) => classCtorSymbol(sym) + case Value.SimpleRef(sym) => classCtorSymbol(sym) case Value.MemberRef(_, disamb) => classCtorSymbol(disamb) orElse disamb.asCls orElse disamb.asObj case Value.This(sym) => classCtorSymbol(sym) orElse sym.asCls case s: Select => s.symbol.flatMap(classCtorSymbol) @@ -408,7 +408,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using def isEnclosingMatchScrutSym(sym: Symbol): Boolean = ctx.exists: case InCtx.MtchBody(m, _) => m.scrut match - case Value.SimpleRef(s, _) => s is sym + case Value.SimpleRef(s) => s is sym case Value.MemberRef(bms, disamb) => disamb is sym case _ => false case _ => false @@ -559,7 +559,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using case p: Path => applyPath(p) private def applyValueSimpleRef(v: Value.SimpleRef, recordAffinity: Bool) = - val Value.SimpleRef(l, _) = v + val Value.SimpleRef(l) = v l match case s: TermSymbol => recordRefInCaptures(s) @@ -583,7 +583,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using case p@TrackableFieldSelect(qual, _ -> _) => res.selToCtxOfSel.addOne(p.uid -> ctxTracker.getAllCtx) qual match - case v@Value.SimpleRef(l, _) + case v@Value.SimpleRef(l) if ctxTracker.isEnclosingMatchScrutSym(l) => applyValueSimpleRef(v, recordAffinity = false) case v@Value.MemberRef(bms, disamb) @@ -595,7 +595,7 @@ class FlowPreAnalyzer(val pgrm: Program)(using case v: Value => applyValue(v) override def applyValue(v: Value): Unit = v match - case v@Value.SimpleRef(l, _) => applyValueSimpleRef(v, recordAffinity = true) + case v@Value.SimpleRef(l) => applyValueSimpleRef(v, recordAffinity = true) case v@Value.MemberRef(_, _) => applyValueMemberRef(v, recordAffinity = true) case Value.This(sym) => () case Value.Lit(lit) => () diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index 2e60402bcd..0cf4b49d72 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -179,25 +179,25 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case Value.MemberRef(bms, disamb) => if disamb.shouldBeLifted then doc"${scope.lookup_!(bms, bms.toLoc)}.class" else scope.lookup_!(bms, r.toLoc) - case Value.SimpleRef(l: BuiltinSymbol, _) => + case Value.SimpleRef(l: BuiltinSymbol) => if l.nullary then l.nme else errExpr(msg"Illegal reference to builtin symbol '${l.nme}'") - case Value.SimpleRef(l: semantics.TermSymbol, _) => + case Value.SimpleRef(l: semantics.TermSymbol) => l.owner match case S(owner) => lastWords(s"Unexpected SimpleRef of TermSymbol with owner: `$l` (owner: `$owner`)") case N => scope.lookup_!(l, r.toLoc) - case Value.SimpleRef(l, _) => scope.lookup_!(l, r.toLoc) - case Call(Value.SimpleRef(l: BuiltinSymbol, _), (lhs :: rhs :: Nil) :: Nil) if !l.functionLike => + case Value.SimpleRef(l) => scope.lookup_!(l, r.toLoc) + case Call(Value.SimpleRef(l: BuiltinSymbol), (lhs :: rhs :: Nil) :: Nil) if !l.functionLike => if l.binary then val res = doc"${operand(lhs)} ${l.nme} ${operand(rhs)}" if needsParens(l.nme) then doc"(${res})" else res else errExpr(msg"Cannot call non-binary builtin symbol '${l.nme}'") - case Call(Value.SimpleRef(l: BuiltinSymbol, _), (rhs :: Nil) :: Nil) if !l.functionLike => + case Call(Value.SimpleRef(l: BuiltinSymbol), (rhs :: Nil) :: Nil) if !l.functionLike => if l.unary then val res = doc"${l.nme} ${operand(rhs)}" if needsParens(l.nme) then doc"(${res})" else res else errExpr(msg"Cannot call non-unary builtin symbol '${l.nme}'") - case Call(Value.SimpleRef(l: BuiltinSymbol, _), args :: Nil) => + case Call(Value.SimpleRef(l: BuiltinSymbol), args :: Nil) => if l.functionLike then val argsDoc = args.map(argument).mkDocument(", ") doc"${l.nme}(${argsDoc})" @@ -317,7 +317,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: val scrutSym = scrut.map(_.sym) b match case Match( - scrut_ @ Value.SimpleRef(scrutSym_, _), // The scrutinee is a ref. + scrut_ @ Value.SimpleRef(scrutSym_), // The scrutinee is a ref. (Case.Lit(Tree.IntLit(curVal_)), b) :: Nil, // There is only one case matching an int literal. S(End(_)) | N, rest // Default case exists and does nothing. ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index da7a1fe74e..70bb4401d2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -225,7 +225,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): def parentFromPath(p: Path): Ls[Local] = p match case Value.MemberRef(bms, disamb) => fromMemToClass(disamb) :: Nil case Value.This(sym) => fromMemToClass(sym) :: Nil - case Value.SimpleRef(l, _) => + case Value.SimpleRef(l) => // TODO(Derppening): Check if this assertion holds bErrStop(msg"Expected parent to be a MemberRef") case _ => bErrStop(msg"Unsupported parent path ${p.toString()}") @@ -281,7 +281,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): private def bValue(v: Value)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bValue { $v } begin", x => s"bValue end: ${x.show}"): v match - case Value.SimpleRef(l: TermSymbol, _) if l.owner.nonEmpty => + case Value.SimpleRef(l: TermSymbol) if l.owner.nonEmpty => k(l |> sr) case Value.MemberRef(bms, disamb) if bms.nme.isCapitalized => val v: Local = newTemp @@ -293,18 +293,18 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): val paramsList = PlainParamList( (0 until f.paramsSize).zip(tempSymbols).map((_n, sym) => Param(FldFlags.empty, sym, N, Modulefulness.none)).toList) - val app = Call(v, tempSymbols.map(x => Arg(N, x.asSimpleRef(N))).toList ne_:: Nil)(true, false, false) + val app = Call(v, tempSymbols.map(x => Arg(N, x.asSimpleRef)).toList ne_:: Nil)(true, false, false) bLam(Lambda(paramsList, Return(app))(Nil), S(bms.nme), N)(k) case None => k(ctx.findName(bms) |> sr) - case Value.SimpleRef(l, _) => + case Value.SimpleRef(l) => ctx.fn_ctx.get(l) match case Some(f) => val tempSymbols = (0 until f.paramsSize).map(x => newNamed("arg")) val paramsList = PlainParamList( (0 until f.paramsSize).zip(tempSymbols).map((_n, sym) => Param(FldFlags.empty, sym, N, Modulefulness.none)).toList) - val app = Call(v, tempSymbols.map(x => Arg(N, x.asSimpleRef(N))).toList ne_:: Nil)(true, false, false) + val app = Call(v, tempSymbols.map(x => Arg(N, x.asSimpleRef)).toList ne_:: Nil)(true, false, false) bLam(Lambda(paramsList, Return(app))(Nil), S(l.nme), N)(k) case None => k(ctx.findName(l) |> sr) @@ -379,7 +379,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): r match case Call(_, argss) if argss.sizeIs > 1 => bErrStop(msg"Calls with multiple argument lists are not yet supported in LLIR") - case Call(Value.SimpleRef(sym: BuiltinSymbol, _), argss) => + case Call(Value.SimpleRef(sym: BuiltinSymbol), argss) => bArgs(argss.flatten): case args: Ls[TrivialExpr] => val v: Local = newTemp @@ -421,7 +421,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): case args: Ls[TrivialExpr] => val v: Local = newTemp Node.LetCall(Ls(v), builtin, Expr.Literal(Tree.StrLit(mathPrimitive)) :: args, k(v |> sr)) - case Call(s @ Select(r @ Value.SimpleRef(sym, _), Tree.Ident(fld)), argss) if s.symbol.isDefined => + case Call(s @ Select(r @ Value.SimpleRef(sym), Tree.Ident(fld)), argss) if s.symbol.isDefined => bPath(r): case r => bArgs(argss.flatten): @@ -503,7 +503,7 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): summon[Ctx].def_acc += jpdef Node.Case(e, casesList, defaultCase) case Return(res) => bResult(res)(x => Node.Result(Ls(x))) - case Throw(Instantiate(false, Select(Value.SimpleRef(_, _), ident), + case Throw(Instantiate(false, Select(Value.SimpleRef(_), ident), Ls(Arg(N, Value.Lit(Tree.StrLit(e)))) :: Nil)) if ident.name === "Error" => Node.Panic(e) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 7522f3dc6f..69e93b7f3f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -926,7 +926,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: def splitSuperTail(block: Block): Opt[Block -> Ls[Arg]] = block match case End(_) => N - case Assign(lhs, Call(Value.SimpleRef(bs: BuiltinSymbol, _), argss), _: End) + case Assign(lhs, Call(Value.SimpleRef(bs: BuiltinSymbol), argss), _: End) if (lhs is State.noSymbol) && (bs is State.superSymbol) => S(End("") -> argss.flatten) @@ -1203,7 +1203,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: operands = Seq(ref.i31(i32.const(lit.offset)), ref.i31(i32.const(lit.byteLen))), returnTypes = Seq(Result(RefType.anyref)), ) - case Value.SimpleRef(l, _) => + case Value.SimpleRef(l) => singletonInfoFor(l) match case S(info) => singletonGlobalGet(info) case N => @@ -1236,7 +1236,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ), ) - case Call(Value.SimpleRef(l: BuiltinSymbol, _), lhs :: rhs :: Nil) if !l.functionLike => + case Call(Value.SimpleRef(l: BuiltinSymbol), lhs :: rhs :: Nil) if !l.functionLike => if l.binary then errExpr( Ls( @@ -1283,9 +1283,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ) case N => fun match - case Value.SimpleRef(l, _) => + case Value.SimpleRef(l) => val base = fun match - case Value.SimpleRef(l, _) => ctx.getFunc(l) + case Value.SimpleRef(l) => ctx.getFunc(l) case Value.MemberRef(l, _) => ctx.getFunc(l) case _ => N val baseFuncIdx = base match @@ -1486,7 +1486,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Returns the intrinsic name if `path` refers to a builtin under `wasm`, or `N` otherwise. */ private def wasmIntrinsicName(path: Path): Opt[Str] = path match - case Select(Value.SimpleRef(sym, _), ident) if (sym eq State.wasmSymbol) && wasmIntrinsicNameSet.contains(ident.name) => + case Select(Value.SimpleRef(sym), ident) if (sym eq State.wasmSymbol) && wasmIntrinsicNameSet.contains(ident.name) => S(ident.name) case _ => N diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index b1cb610676..98424e1d2d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -367,7 +367,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C for (_, s) <- entries do LoweringCtx.loweringCtx.collectScopedSym(s) val objectSym = ctx.builtins.Object mkMatch( // checking that we have an object - Case.Cls(objectSym, BuiltinSymbol(objectSym.nme, false, false, true, false).asSimpleRef(codegen.ErasedType.AnyRef(false, objectSym))), + Case.Cls(objectSym, BuiltinSymbol(objectSym.nme, false, false, true, false).asSimpleRef), entries.foldRight(lowerSplit(tail, cont)): case ((fieldName, fieldSymbol), blk) => mkMatch( @@ -490,7 +490,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C if useNestedScoped then LoweringCtx.loweringCtx.getCollectedSym else Set.empty, mainBlock) // Embed the `body` into `Label` if the term is a `while`. - lazy val rest = if usesResTmp then k(l.asSimpleRef(N)) else k(lowering.unit) + lazy val rest = if usesResTmp then k(l.asSimpleRef) else k(lowering.unit) val block = if form === IfLikeForm.While then // NOTE: `shouldRewriteWhile` is not the same as `config.rewriteWhileLoops` @@ -501,16 +501,16 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C outerCtx.collectScopedSym(loopResult) outerCtx.collectScopedSym(isReturned) val loopEnd: Path = - Select(State.runtimeSymbol.asSimpleRef(N), Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) + Select(State.runtimeSymbol.asSimpleRef, Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) val blk = blockBuilder .define(FunDefn(N, f, tSym, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd)))(configOverride = N, annotations = Nil)) .assign(loopResult, Call(f.asMemberRef(tSym), Nil ne_:: Nil)(true, true, false)) if summon[LoweringCtx].mayRet then blk - .assign(isReturned, Call(State.builtinOpsMap("!==").asSimpleRef(N), + .assign(isReturned, Call(State.builtinOpsMap("!==").asSimpleRef, (loopResult.asPath.asArg :: loopEnd.asArg :: Nil) ne_:: Nil)(true, false, false)) - .ifthen(isReturned.asSimpleRef(codegen.ErasedType.Primitive(codegen.PrimitiveType.Bool)), Case.Lit(Tree.BoolLit(true)), - Return(loopResult.asSimpleRef(N)), + .ifthen(isReturned.asSimpleRef, Case.Lit(Tree.BoolLit(true)), + Return(loopResult.asSimpleRef), N ) .rest(rest) From 0379ae99e048d11228401ff9f35e03ee055e1bae Mon Sep 17 00:00:00 2001 From: David Mak Date: Sat, 30 May 2026 11:00:02 +0800 Subject: [PATCH 07/48] codegen: Implement `erasedType` in symbols --- .../src/main/scala/hkmc2/codegen/Block.scala | 16 ++- .../scala/hkmc2/codegen/BlockSimplifier.scala | 4 +- .../hkmc2/codegen/BufferableTransform.scala | 25 ++--- .../scala/hkmc2/codegen/DeadParamElim.scala | 6 +- .../scala/hkmc2/codegen/EtaExpansion.scala | 4 +- .../FirstClassFunctionTransformer.scala | 4 +- .../scala/hkmc2/codegen/HandlerLowering.scala | 14 +-- .../src/main/scala/hkmc2/codegen/Lifter.scala | 44 ++++---- .../main/scala/hkmc2/codegen/Lowering.scala | 84 +++++++-------- .../codegen/ReflectionInstrumenter.scala | 4 +- .../hkmc2/codegen/StackSafeTransform.scala | 10 +- .../scala/hkmc2/codegen/SymbolRefresher.scala | 22 ++-- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 16 +-- .../scala/hkmc2/codegen/WorkerWrapper.scala | 2 +- .../hkmc2/codegen/deforest/Rewrite.scala | 20 ++-- .../scala/hkmc2/codegen/js/JSBuilder.scala | 6 +- .../scala/hkmc2/codegen/llir/Builder.scala | 10 +- .../scala/hkmc2/codegen/wasm/text/Ctx.scala | 2 +- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 61 +++++------ .../main/scala/hkmc2/invalml/InvalML.scala | 32 +++--- .../scala/hkmc2/semantics/Elaborator.scala | 101 +++++++++--------- .../main/scala/hkmc2/semantics/Importer.scala | 2 +- .../main/scala/hkmc2/semantics/Pattern.scala | 2 +- .../main/scala/hkmc2/semantics/Symbol.scala | 38 ++++--- .../hkmc2/semantics/ucs/Normalization.scala | 11 +- .../hkmc2/semantics/ucs/SplitElaborator.scala | 6 +- .../hkmc2/semantics/ucs/TermSynthesizer.scala | 4 +- .../scala/hkmc2/semantics/ups/Compiler.scala | 44 ++++---- .../hkmc2/semantics/ups/SplitCompiler.scala | 93 ++++++++-------- .../src/main/scala/hkmc2/syntax/Tree.scala | 2 +- .../test/mlscript/block-staging/Functions.mls | 2 +- .../test/mlscript/codegen/BlockPrinter.mls | 6 +- .../test/scala/hkmc2/JSBackendDiffMaker.scala | 2 +- 33 files changed, 356 insertions(+), 343 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 1f41794194..6b9f55bbf4 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -527,7 +527,7 @@ object HandleBlock: N, sym, PlainParamList(Param(FldFlags.empty, handler.resumeSym, N, Modulefulness.none) :: Nil) :: Nil, handler.body )(N, annotations = Nil) - val rSym = TempSymbol(N, "suspendRes") + val rSym = TempSymbol(N, erasedType = N, "suspendRes") FunDefn.withFreshSymbol( S(cls), handler.sym, @@ -680,7 +680,7 @@ final case class FunDefn( object FunDefn: def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(configOverride: Opt[Config], annotations: Ls[Annot])(using State) = - val tSym = TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)) + val tSym = TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme), erasedType = N) sym.tsym = S(tSym) FunDefn(owner, sym, tSym, params, body)(configOverride, annotations) @@ -706,7 +706,8 @@ object ValDefn: annotations: Ls[Annot], )(using State) : ValDefn = - ValDefn(tsym = TermSymbol(k, owner, Tree.Ident(sym.nme)), sym = sym, rhs = rhs)(configOverride, annotations) + // TODO(Derppening): We can probably use the erasedType from `rhs` once Path implements `HasErasedType` + ValDefn(tsym = TermSymbol(k, owner, Tree.Ident(sym.nme), erasedType = N), sym = sym, rhs = rhs)(configOverride, annotations) /* @@ -1058,12 +1059,9 @@ enum Value extends Path with HasErasedType with ProductWithExtraInfo: /** The [[`ErasedType`]] of this value. */ val erasedType: Opt[ErasedType] = this match - case SimpleRef(_, erasedType) => erasedType - case MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol)) => S(ErasedType.AnyRef(false, disamb)) - case MemberRef(_, disamb: TypeAliasSymbol) => - // TODO(Derppening): Do we preserve the `TypeAliasSymbol` here? - disamb.irClsLikeDefn.flatMap(_.sym.asClsOrMod).map(ErasedType.AnyRef(false, _)) - case This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => S(ErasedType.AnyRef(false, clsOrMod)) + case SimpleRef(sym) => sym.erasedType + case MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => disamb.erasedType + case This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => clsOrMod.erasedType case Lit(lit) => S(lit.erasedType) case _ => N diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala index 3292609efd..467a584b8c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala @@ -1179,7 +1179,7 @@ class BlockSimplifier def go(acc: Block => Block, args: List[(VarSymbol, Result)], mapping: Map[Symbol, Symbol]): Block = args match case Nil => - val resSym = TempSymbol(N, "inlinedVal") + val resSym = TempSymbol(N, erasedType = N, "inlinedVal") val copier = Copier(resSym, mapping) val newBlk = copier.applyBlock(blk) if extraArgss.isEmpty then @@ -1188,7 +1188,7 @@ class BlockSimplifier acc(Scoped(Set(resSym), newBlk( k(Call(resSym.asSimpleRef, extraArgss.ne_!)(c.isMlsFun, c.mayRaiseEffects, false))))) case (sym, value) :: argRest => - val newSym = VarSymbol(sym.id) + val newSym = VarSymbol(sym.id, erasedType = N) go(acc.assignScoped(newSym, value), argRest, mapping + (sym -> newSym)) go(blockBuilder, matchedArgs, Map.empty) case _ => super.applyResult(r)(k) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala index d5a102b1ed..a37407535e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala @@ -22,7 +22,7 @@ class BufferableTransform()(using Ctx, State, Raise): cls.bufferable.fold(super.applyDefn(defn)(k)): bufferable => val companionSym = ModuleOrObjectSymbol(DummyTypeDef(syntax.Mod), new Tree.Ident(cls.sym.nme)) val clsSizeSym = BlockMemberSymbol("size", Nil, false) - val clsSizeTermSym = TermSymbol(syntax.ImmutVal, S(companionSym), new Tree.Ident("size")) + val clsSizeTermSym = TermSymbol(syntax.ImmutVal, S(companionSym), new Tree.Ident("size"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int))) val pubFieldMap: Map[Symbol, Symbol] = cls.publicFields.toMap val fields = cls.privateFields ++ cls.publicFields.map(_._2) val fieldMap: Map[Symbol, Int] = fields.zipWithIndex.toMap @@ -30,23 +30,20 @@ class BufferableTransform()(using Ctx, State, Raise): val allVars = params.flatMap(_.allParams).map(_.sym) val varMap = allVars .map: sym => - (sym, VarSymbol(sym.id)) + (sym, VarSymbol(sym.id, erasedType = N)) .toMap def mapParam(p: Param) = Param(p.flags, varMap(p.sym), p.sign, p.modulefulness) (params.map(pl => ParamList(pl.flags, pl.params.map(mapParam), pl.restParam.map(mapParam))), varMap.toMap) def mkFieldReplacer(buf: VarSymbol, baseIdx: VarSymbol, symMap: Map[Symbol, Symbol]) = - val baseIdxRef = baseIdx.asSimpleRef def getOffset(off: Int)(k: Path => Block): Block = - val idxSymbol = new TempSymbol(N, "idx") - val idxSymbolRef = idxSymbol.asSimpleRef - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdxRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), - k(DynSelect(buf.asSimpleRef.selSN("buf"), idxSymbolRef, true)))) + val idxSymbol = new TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "idx") + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), + k(DynSelect(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true)))) def assignToOffset(off: Int, r: Result, rst: Block) = - val idxSymbol = new TempSymbol(N, "idx") - val idxSymbolRef = idxSymbol.asSimpleRef - Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdxRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), - AssignDynField(buf.asSimpleRef.selSN("buf"), idxSymbolRef, true, r, applyBlock(rst)))) + val idxSymbol = new TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "idx") + Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(true, false, false), + AssignDynField(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true, r, applyBlock(rst)))) new BlockTransformer(SymbolSubst.Id): override def applyLocal(sym: Local): Local = symMap.getOrElse(sym, sym) override def applyBlock(b: Block): Block = b match @@ -76,11 +73,11 @@ class BufferableTransform()(using Ctx, State, Raise): k(res) case _ => super.applyPath(p)(k) def transformFunDefn(f: FunDefn, isCtor: Bool): FunDefn = - val buf = VarSymbol(new Tree.Ident("buf")) - val idx = VarSymbol(new Tree.Ident("idx")) + val buf = VarSymbol(new Tree.Ident("buf"), erasedType = N) + val idx = VarSymbol(new Tree.Ident("idx"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int))) val (newParams, symMap) = mkSymbolReplacer(f.params) val blk = mkFieldReplacer(buf, idx, symMap).applyBlock(f.body) - FunDefn(f.owner, f.sym, TermSymbol(f.dSym.k, f.dSym.owner, f.dSym.id), PlainParamList( + FunDefn(f.owner, f.sym, TermSymbol(f.dSym.k, f.dSym.owner, f.dSym.id, erasedType = N), PlainParamList( Param(FldFlags.empty, buf, N, Modulefulness.none) :: Param(FldFlags.empty, idx, N, Modulefulness.none) :: Nil) :: newParams, if isCtor then Begin(blk, Return(idx.asSimpleRef)) else blk)(configOverride = f.configOverride, annotations = f.annotations) val fakeCtor = transformFunDefn(FunDefn.withFreshSymbol( diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala index 2b75982a87..5909a48e3b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala @@ -166,7 +166,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise): val name = instId.mkFunName + s"$$${f.nme}" f -> ( new BlockMemberSymbol(name, Nil, true), - new TermSymbol(Fun, N, Tree.Ident(name))) + new TermSymbol(Fun, N, Tree.Ident(name), erasedType = N)) .toMap) end mkNewPolyFnSyms @@ -334,12 +334,12 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise): case ParamList(flags, params, restParam) => val params2 = params.map: case p => - val newSym = new VarSymbol(Tree.Ident(p.sym.name)) + val newSym = new VarSymbol(Tree.Ident(p.sym.name), erasedType = N) refreshParamMap(p.sym) = newSym Param(p.flags, newSym, p.sign, p.modulefulness) val rest2 = restParam.map: case p => - val newSym = new VarSymbol(Tree.Ident(p.sym.name)) + val newSym = new VarSymbol(Tree.Ident(p.sym.name), erasedType = N) refreshParamMap(p.sym) = newSym Param(p.flags, newSym, p.sign, p.modulefulness) ParamList(flags, params2, rest2) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala index 7bcca2730a..b8658ec22e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala @@ -172,7 +172,7 @@ class EtaExpansionRewrite(val etaExpansionSolver: EtaExpansionSolver)(using Rais targetShape.drop(existingShape.size).zipWithIndex.map: case (count, idx) => val params = (0 until count).toList.map: i => - Param.simple(new VarSymbol(new Tree.Ident(s"eta$$$idx$$$i"))) + Param.simple(new VarSymbol(new Tree.Ident(s"eta$$$idx$$$i"), erasedType = N)) EtaParamList( ParamList(ParamListFlags.empty, params, N), params.map(p => Arg(N, p.sym.asSimpleRef)), @@ -194,7 +194,7 @@ class EtaExpansionRewrite(val etaExpansionSolver: EtaExpansionSolver)(using Rais Return( Call(fun, (argss ++ activeEtaArgss).ne_!)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall)) case _ => - val tmp = TempSymbol(N, "eta$res") + val tmp = TempSymbol(N, erasedType = N, "eta$res") Scoped( Set.single(tmp), Assign(tmp, res2, Return(etaCall(tmp.asPath).withLocOf(res2)))) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala index 0d924e89ce..876297b6dc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala @@ -47,7 +47,7 @@ class FirstClassFunctionTransformer case s: TermSymbol if s.k is syntax.Fun => val params = getParamList(l).getOrElse(lastWords(s"Cannot get ${l.nme}'s parameter list.")) val clsDef = generateFCFunctionClass(ref, params) - val tmp = new TempSymbol(None) + val tmp = new TempSymbol(None, erasedType = S(ErasedType.AnyRef(rsc = false, clsDef.isym.asClsOrMod.get))) val cls = clsDef.sym.asMemberRef(clsDef.isym) Scoped( syms = Set(clsDef.sym, tmp), @@ -70,7 +70,7 @@ class FirstClassFunctionTransformer source = Diagnostic.Source.Compilation) PlainParamList(Nil) val clsDef = generateFCFunctionClass(sel, params) - val tmp = new TempSymbol(None) + val tmp = new TempSymbol(None, erasedType = S(ErasedType.AnyRef(rsc = false, clsDef.isym.asClsOrMod.get))) val cls = clsDef.sym.asMemberRef(clsDef.isym) Scoped( Set(clsDef.sym, tmp), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 11263f63b2..d391891302 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -116,7 +116,7 @@ type StackSafetyMap = collection.Map[FnOrCls, (Int, Block)] class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, Elaborator.State, Elaborator.Ctx): - private def freshTmp(dbgNme: Str = "tmp") = new TempSymbol(N, dbgNme) + private def freshTmp(erasedType: Opt[ErasedType], dbgNme: Str = "tmp") = new TempSymbol(N, erasedType, dbgNme) private def freshLabel(nme: Str) = new LabelSymbol(N, nme) private def rtThrowMsg(msg: Str) = Throw( @@ -136,7 +136,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case _ => N object StateTransition: - private val transitionSymbol = freshTmp("transition") + private val transitionSymbol = freshTmp(erasedType = N, "transition") def apply(uid: StateId) = Return(PureCall(transitionSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid))))) def unapply(blk: Block) = blk match @@ -145,7 +145,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case _ => N object Unwind: - private val unwindSymbol = freshTmp("unwind") + private val unwindSymbol = freshTmp(erasedType = N, "unwind") def apply(uid: StateId, loc: Value) = Return(PureCall(unwindSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid)), loc))) def unapply(blk: Block) = blk match @@ -537,7 +537,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, .flatMap: (sym, idx) => List(intLit(idx), Value.Lit(Tree.StrLit(sym.nme))) .map(_.asArg) - val debugInfoSym = freshTmp(s"$debugNme$$debugInfo") + val debugInfoSym = freshTmp(erasedType = N, s"$debugNme$$debugInfo") // TODO: properly support spread argument by calculating the correct length. val rtArgLists = intLit(fun.params.length) :: fun.params.flatMap: pl => intLit(pl.params.length) :: pl.params.map(p => p.sym.asSimpleRef) @@ -612,7 +612,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, return b val vars = if opt.debug then ctx.resumeInfo.currentLocals else computeRestoreList(parts) - val pcVar = freshTmp("pc") + val pcVar = freshTmp(erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "pc") val mainLoopLbl = freshLabel("main") val edges = computeEdges(parts) @@ -687,7 +687,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case (acc, f) => f(acc) Label(mainLoopLbl, true, matches, End()) - val getSavedTmp = freshTmp("saveOffset") + val getSavedTmp = freshTmp(erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "saveOffset") def getSaved(off: BigInt): (Block => Block, Path) = if off == 0 then return (id, DynSelect(paths.runtimePath.selSN("resumeArr"), paths.runtimePath.selSN("resumeIdx"), true)) @@ -736,7 +736,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, override def applyResult(r: Result)(k: Result => Block) = r match case EffectfulResult(r) => // Fallback case, this may lead to unnecessary assignments if it is assign-like - val l = freshTmp() + val l = freshTmp(erasedType = N) Scoped(Set(l), effectCheck(l, r, k(l.asSimpleRef))) case _ => super.applyResult(r)(k) topLevelTransform.applyBlock(b) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 56a802ebc3..1ba3e31c5d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -375,7 +375,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): else val newSym = closureMap.get(l) match case None => - val newSym = TempSymbol(N, l.nme + "$here") + val newSym = TempSymbol(N, erasedType = N, l.nme + "$here") extraLocals.add(newSym) syms.addOne(FunSyms(l, d) -> newSym) // add to `syms`: this closure will be initialized in `applyBlock` closureMap.addOne(l -> newSym) // add to `closureMap`: `newSym` refers to the closure and can be used later @@ -395,7 +395,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): case cls: LiftedClass if !cls.isTrivial => val newSym = closureMap.get(l) match case None => - val newSym = TempSymbol(N, l.nme + "$here") + val newSym = TempSymbol(N, erasedType = N, l.nme + "$here") extraLocals.add(newSym) syms.addOne(FunSyms(l, d) -> newSym) closureMap.addOne(l -> newSym) @@ -547,9 +547,9 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val nme = sym.nme + "$" + id val ident = new Tree.Ident(nme) - val varSym = VarSymbol(ident) + val varSym = VarSymbol(ident, erasedType = N) val fldSym = BlockMemberSymbol(nme, Nil) - val tSym = TermSymbol(syntax.MutVal, S(clsSym), ident) + val tSym = TermSymbol(syntax.MutVal, S(clsSym), ident, erasedType = N) val p = Param(FldFlags.empty.copy(isVal = true), varSym, N, Modulefulness.none) varSym.decl = S(p) // * Currently this is only accessed to create the class' toString method @@ -832,11 +832,11 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): * A rewritten scope with a generic VarSymbol capture symbol. */ sealed trait GenericRewrittenScope[T] extends RewrittenScope[T]: - lazy val captureSym = VarSymbol(Tree.Ident(obj.nme + "$cap")) + lazy val captureSym = VarSymbol(Tree.Ident(obj.nme + "$cap"), erasedType = N) override lazy val capturePath = captureSym.asSimpleRef protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid) protected val liftedObjsSyms: Map[InnerSymbol, VarSymbol] = liftedObjsOrdered.map: s => - s -> VarSymbol(Tree.Ident(s.nme + "$")) + s -> VarSymbol(Tree.Ident(s.nme + "$"), erasedType = N) .toMap override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = liftedObjsSyms.map: case k -> v => k -> v.asLocalPath @@ -847,11 +847,11 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): * A rewritten scope with a TermSymbol capture symbol. */ sealed trait ClsLikeRewrittenScope[T](sym: InnerSymbol) extends RewrittenScope[T]: - lazy val captureSym = TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(obj.nme + "$cap")) + lazy val captureSym = TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(obj.nme + "$cap"), erasedType = N) override lazy val capturePath = captureSym.asSimpleRef protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid) protected val liftedObjsSyms: Map[InnerSymbol, TermSymbol] = liftedObjsOrdered.map: s => - s -> TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(s.nme + "$")) + s -> TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(s.nme + "$"), erasedType = N) .toMap override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = liftedObjsSyms.map: case k -> v => k -> LocalPath.privateSelfField(v) @@ -867,7 +867,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): LifterResult(liftedMtds, extras.flatten) // some helpers - private def dupParam(p: Param): Param = p.copy(sym = VarSymbol(Tree.Ident(p.sym.nme))) + private def dupParam(p: Param): Param = p.copy(sym = VarSymbol(Tree.Ident(p.sym.nme), erasedType = N)) private def dupParams(plist: List[Param]): List[Param] = plist.map(dupParam) private def dupParamList(plist: ParamList): ParamList = plist.copy(params = dupParams(plist.params), restParam = plist.restParam.map(dupParam)) @@ -914,7 +914,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): extends RewrittenScope[ClsLikeDefn](obj) with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): - private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap"), erasedType = N) override lazy val capturePath: Path = captureSym.asSimpleRef override def rewriteImpl: LifterResult[ClsLikeDefn] = @@ -937,7 +937,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): extends RewrittenScope[ClsLikeBody](obj) with ClsLikeRewrittenScope[ClsLikeBody](obj.clsBody.isym): - private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.clsBody.isym), Tree.Ident(obj.nme + "$cap")) + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.clsBody.isym), Tree.Ident(obj.nme + "$cap"), erasedType = N) override lazy val capturePath: Path = captureSym.asSimpleRef override def rewriteImpl: LifterResult[ClsLikeBody] = @@ -954,15 +954,15 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): class LiftedFunc(override val obj: ScopedObject.Func)(using ctx: LifterCtxNew) extends LiftedScope[FunDefn](obj) with GenericRewrittenScope[FunDefn]: private val passedSymsMap_ : Map[Local, VarSymbol] = passedSymsOrdered.map: s => - s -> VarSymbol(Tree.Ident(s.nme)) + s -> VarSymbol(Tree.Ident(s.nme), erasedType = N) .toMap private val capSymsMap_ : Map[ScopedInfo, VarSymbol] = capturesOrdered.map: i => val nme = data.getNode(i).obj.nme - i -> VarSymbol(Tree.Ident(nme + "$cap")) + i -> VarSymbol(Tree.Ident(nme + "$cap"), erasedType = N) .toMap private val defnSymsMap_ : Map[DefinitionSymbol[?], VarSymbol] = reqDefnsOrdered.sortBy(_.uid).map: i => val nme = data.getNode(i).obj.nme - i -> VarSymbol(Tree.Ident(nme + "$")) + i -> VarSymbol(Tree.Ident(nme + "$"), erasedType = N) .toMap override protected val passedSymsMap = passedSymsMap_.view.mapValues(_.asLocalPath).toMap @@ -1064,29 +1064,29 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): extends LiftedScope[ClsLikeDefn](obj) with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): - private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap"), erasedType = N) override lazy val capturePath: Path = captureSym.asSimpleRef private val passedSymsMap_ : Map[Local, (vs: VarSymbol, ts: TermSymbol)] = passedSymsOrdered.map: s => s -> ( - VarSymbol(Tree.Ident(s.nme)), - TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(s.nme)) + VarSymbol(Tree.Ident(s.nme), erasedType = N), + TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(s.nme), erasedType = N) ) .toMap private val capSymsMap_ : Map[ScopedInfo, (vs: VarSymbol, ts: TermSymbol)] = capturesOrdered.map: i => val nme = data.getNode(i).obj.nme + "$cap" i -> ( - VarSymbol(Tree.Ident(nme)), - TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(nme)) + VarSymbol(Tree.Ident(nme), erasedType = N), + TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(nme), erasedType = N) ) .toMap private val defnSymsMap_ : Map[DefinitionSymbol[?], (vs: VarSymbol, ts: TermSymbol)] = reqDefnsOrdered.map: i => i -> ( - VarSymbol(Tree.Ident(i.nme + "$")), - TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(i.nme + "$")) + VarSymbol(Tree.Ident(i.nme + "$"), erasedType = N), + TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(i.nme + "$"), erasedType = N) ) .toMap @@ -1118,7 +1118,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): // Contains *all* parameters, and applies them all at once in a single `Instantiate` def mkFlattenedDefn: FunDefn = // Symbols for the aux parameter list - val auxSyms = auxParams.map(p => VarSymbol(Tree.Ident(p.sym.nme))) + val auxSyms = auxParams.map(p => VarSymbol(Tree.Ident(p.sym.nme), erasedType = N)) val auxParamListLocal = PlainParamList(auxSyms.map(Param.simple(_))) val dupedClsAuxParams = cls.auxParams.map(dupParamList(_)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 78c5b01b80..f80acbf5a0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -43,8 +43,8 @@ class LoweringCtx( val map = initMap def collectScopedSym(s: Symbol) = definedSymsDuringLowering.add(s) def collectScopedSyms(s: Symbol*) = definedSymsDuringLowering.addAll(s) - def registerTempSymbol(trm: Option[Term], dbgNme: Str = "tmp")(using State) = - val tmp = new TempSymbol(trm, dbgNme) + def registerTempSymbol(trm: Option[Term], erasedType: Opt[ErasedType], dbgNme: Str = "tmp")(using State) = + val tmp = new TempSymbol(trm, erasedType, dbgNme) definedSymsDuringLowering.add(tmp) tmp def getCollectedSym: collection.Set[Symbol] = definedSymsDuringLowering @@ -413,7 +413,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): acc.reverse match case Nil => lowerRemainingCalls(fr, args, remainingArgss, isTailCall, loc)(k) case acc: NELs[Ls[Arg]] => - val tmp = loweringCtx.registerTempSymbol(N, "baseCall") + val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, "baseCall") val call = Call(fr, acc)(isMlsFun, true, isTailCall).withLoc(loc) Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall, loc)(k)) case (_ :: _, Nil) => @@ -433,7 +433,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): remainingArgss match case Nil => k(call) case args :: remainingArgss => - val tmp = loweringCtx.registerTempSymbol(N, "callPrefix") + val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, "callPrefix") Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall, loc)(k)) @@ -457,7 +457,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case (Nil, Nil) => k(buildInstantiate(acc.reverse)) case (Nil, args :: remainingArgss) => - val tmp = loweringCtx.registerTempSymbol(N, "baseInst") + val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, "baseInst") Assign(tmp, buildInstantiate(acc.reverse), lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall = false, N)(k)) case (remainingParamss, Nil) => @@ -467,7 +467,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): remainingParamss match case Nil => buildInstantiate(accArgss) case ps :: rest => - val freshSyms = ps.params.map(p => new VarSymbol(new Tree.Ident(p.sym.nme))) + val freshSyms = ps.params.map(p => new VarSymbol(new Tree.Ident(p.sym.nme), erasedType = N)) softTODO(ps.restParam.isEmpty, "Eta expanding rest parameters in constructor definitions is not yet supported") val freshParams = (ps.params zip freshSyms).map((p, s) => Param(p.flags, s, N, p.modulefulness)) val freshParamList = ParamList(ps.flags, freshParams, N) @@ -496,7 +496,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Nil => k(buildInstantiate(as :: Nil)) case remainingArgss => - val tmp = loweringCtx.registerTempSymbol(N, "baseInst") + val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, "baseInst") Assign(tmp, buildInstantiate(as :: Nil), lowerRemainingCalls(tmp.asSimpleRef, remainingArgss.head, remainingArgss.tail, isTailCall = false, N)(k)) else zipArgs(ctorParamLists, args, Nil) @@ -559,8 +559,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if sym.binary then val t1 = new Tree.Ident("arg1") val t2 = new Tree.Ident("arg2") - val p1 = Param(FldFlags.empty, VarSymbol(t1), N, Modulefulness.none) - val p2 = Param(FldFlags.empty, VarSymbol(t2), N, Modulefulness.none) + val p1 = Param(FldFlags.empty, VarSymbol(t1, erasedType = N), N, Modulefulness.none) + val p2 = Param(FldFlags.empty, VarSymbol(t2, erasedType = N), N, Modulefulness.none) val ps = PlainParamList(p1 :: p2 :: Nil) val bod = st.App(ref, st.Tup(List(st.Ref(p1.sym)(t1, 666, N).resolve, st.Ref(p2.sym)(t2, 666, N).resolve)) (Tree.Tup(Nil // FIXME should not be required (using dummy value) @@ -575,7 +575,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): return k(Lambda(paramLists.head, bodyBlock)(Nil).withLocOf(ref)) if sym.unary then val t1 = new Tree.Ident("arg") - val p1 = Param(FldFlags.empty, VarSymbol(t1), N, Modulefulness.none) + val p1 = Param(FldFlags.empty, VarSymbol(t1, erasedType = N), N, Modulefulness.none) val ps = PlainParamList(p1 :: Nil) val bod = st.App(ref, st.Tup(List(st.Ref(p1.sym)(t1, 666, N).resolve)) (Tree.Tup(Nil // FIXME should not be required (using dummy value) @@ -659,8 +659,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if !hasNonLocalContinueDispatch then term_nonTail(body)(r => Assign(result, r, Break(label))) else - val bodyResult = loweringCtx.registerTempSymbol(N, "labelBodyResult") - val isContinue = loweringCtx.registerTempSymbol(N, "labelContinueDispatch") + val bodyResult = loweringCtx.registerTempSymbol(N, erasedType = N, "labelBodyResult") + val isContinue = loweringCtx.registerTempSymbol(N, erasedType = N, "labelContinueDispatch") term_nonTail(body): r => Assign( bodyResult, @@ -736,7 +736,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): S(k(Value.Lit(negLit))), Unreachable("tail operation in branches"), ) else - val ts = loweringCtx.registerTempSymbol(N) + val ts = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool))) Match( ar1, (Case.Lit(posLit) -> term_nonTail(arg2)(Assign(ts, _, End()))) :: Nil, @@ -862,7 +862,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): S(Handler(td.sym, resumeSym, paramLists, bodyBlock)) }.collect{ case Some(v) => v } loweringCtx.collectScopedSym(lhs) - val resSym = loweringCtx.registerTempSymbol(S(t)) + val resSym = loweringCtx.registerTempSymbol(S(t), erasedType = N) subTerm(rhs): par => subTerms(as): asr => HandleBlock(lhs, resSym, par, asr, cls, handlers, @@ -957,7 +957,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Define(clsDef, term_nonTail(if mut then Mut(inner) else inner)(k)) case Try(sub, finallyDo) => - val l = loweringCtx.registerTempSymbol(S(sub)) + val l = loweringCtx.registerTempSymbol(S(sub), erasedType = N) TryBlock( subTerm_nonTail(sub)(p => Assign(l, p, End())), subTerm_nonTail(finallyDo)(_ => End()), @@ -1033,7 +1033,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): def quoteSplit(split: Split)(k: Result => Block)(using LoweringCtx): Block = split match case Split.Cons(Branch(scrutinee, pattern, continuation), tail) => quote(scrutinee): r1 => - val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N) + val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N, erasedType = N) blockBuilder.assign(l1, r1) .chain(b => quotePattern(pattern)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(continuation)(r3 => Assign(l3, r3, b))) @@ -1042,19 +1042,19 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): .rest(setupTerm("Cons", (l4 :: l5 :: Nil).map(s => s.asSimpleRef))(k)) case Split.Let(sym, term, tail) => setupSymbol(sym): r1 => loweringCtx.collectScopedSym(sym) - val l1, l2, l3 = loweringCtx.registerTempSymbol(N) + val l1, l2, l3 = loweringCtx.registerTempSymbol(N, erasedType = N) blockBuilder.assign(l1, r1) .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(term)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(tail)(r3 => Assign(l3, r3, b))) .rest(setupTerm("Let", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(k)) case Split.Else(default) => quote(default): r => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r, setupTerm("Else", l.asSimpleRef :: Nil)(k)) case Split.End => setupTerm("End", Nil)(k) case Split.LetSplit(sym, tail) => setupSymbol(sym): r1 => loweringCtx.collectScopedSym(sym) - val l1, l2, l3 = loweringCtx.registerTempSymbol(N) + val l1, l2, l3 = loweringCtx.registerTempSymbol(N, erasedType = N) blockBuilder.assign(l1, r1) .chain(b => Assign(sym, Value.Ref(l1), b)) .chain(b => quoteSplit(sym.body)(r2 => Assign(l2, r2, b))) @@ -1070,7 +1070,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Lit(lit) => setupTerm("Lit", Value.Lit(lit) :: Nil)(k) case Ref(sym) if Elaborator.binaryOps.contains(sym.nme) => // builtin symbols - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) setupTerm("Builtin", Value.Lit(Tree.StrLit(sym.nme)) :: Nil)(k) case Resolved(Ref(sym), disamb) => sym match @@ -1080,7 +1080,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Ref(sym) => k(sym.asPath) case SynthSel(Ref(sym: ModuleOrObjectSymbol), name) => // Local cross-stage references setupSymbol(sym): r1 => - val l1, l2 = loweringCtx.registerTempSymbol(N) + val l1, l2 = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.UnitLit(false)) :: Nil)(r2 => Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) )) @@ -1093,7 +1093,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ) else (t.toLoc, sym.toLoc) match case (S(Loc(_, _, Origin(base, _, _))), S(Loc(_, _, Origin(filename, _, _)))) => setupSymbol(sym): r1 => - val l1, l2 = loweringCtx.registerTempSymbol(N) + val l1, l2 = loweringCtx.registerTempSymbol(N, erasedType = N) val basePath = base.up val targetPath = filename val relPath = targetPath.relativeTo(basePath).map(_.toString).getOrElse(targetPath.toString) @@ -1109,8 +1109,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Lam(params, body) => def rec(ps: Ls[LocalSymbol & NamedSymbol], ds: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = ps match case Nil => quote(body): r => - val l = loweringCtx.registerTempSymbol(N) - val arr = loweringCtx.registerTempSymbol(N, "arr") + val l = loweringCtx.registerTempSymbol(N, erasedType = N) + val arr = loweringCtx.registerTempSymbol(N, erasedType = N, "arr") Assign( arr, Tuple(mut = false, ds.reverse.map(_.asArg)), @@ -1118,24 +1118,24 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case sym :: rest => loweringCtx.collectScopedSym(sym) setupSymbol(sym): r => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r, setupTerm("Ref", l.asSimpleRef :: Nil): r1 => Assign(sym, r1, rec(rest, l.asSimpleRef :: ds)(k))) rec(params.params.map(_.sym), Nil)(k) // TODO: restParam? case App(lhs, Tup(rhs)) => quote(lhs): r1 => def rec(es: Ls[Elem], xs: Ls[Path])(k: Result => Block): Block = es match case Nil => - val arrSym = loweringCtx.registerTempSymbol(N, "arr") + val arrSym = loweringCtx.registerTempSymbol(N, erasedType = N, "arr") Assign( arrSym, Tuple(mut = false, xs.reverse.map(_.asArg)), setupTerm("Tup", arrSym.asSimpleRef :: Nil): r2 => - val l1 = loweringCtx.registerTempSymbol(N) - val l2 = loweringCtx.registerTempSymbol(N) + val l1 = loweringCtx.registerTempSymbol(N, erasedType = N) + val l2 = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l1, r1, Assign(l2, r2, setupTerm("App", l1.asSimpleRef :: l2.asSimpleRef :: Nil)(k))) ) case Fld(_, t, _) :: rest => quote(t): r2 => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r2, rec(rest, l.asSimpleRef :: xs)(k)) case Spd(eager, term) :: rest => fail: @@ -1149,8 +1149,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): require(sym2 is sym) loweringCtx.collectScopedSyms(sym) setupSymbol(sym){r1 => - val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N) - val arrSym = loweringCtx.registerTempSymbol(N, "arr") + val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N, erasedType = N) + val arrSym = loweringCtx.registerTempSymbol(N, erasedType = N, "arr") blockBuilder.assign(l1, r1) .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(rhs)(r2 => Assign(l2, r2, b))) @@ -1161,7 +1161,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): .rest(setupTerm("Blk", arrSym.asSimpleRef :: l3.asSimpleRef :: Nil)(k)) } case IfLike(_, IfLikeForm.ReturningIf, split) => quoteSplit(split.getExpandedSplit): r => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r, setupTerm("IfLike", setupQuotedKeyword("If") :: l.asSimpleRef :: Nil)(k)) case Unquoted(body) => term(body)(k) case _ => fail: @@ -1227,7 +1227,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if fsr.isEmpty then Begin(b, k(asr.reverse)) else - val rcdSym = loweringCtx.registerTempSymbol(N, "rcd") + val rcdSym = loweringCtx.registerTempSymbol(N, erasedType = N, "rcd") Begin( b, Assign( @@ -1261,7 +1261,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(configOverride = N, annotations = lam.annot) Define(lamDef, k(lamDef.asPath)) case r => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r, k(l.asSimpleRef)) @@ -1410,7 +1410,7 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State) // * ^ We assume that resolved selections are well-behaved (will not yield undefined or debind a method) then super.setupSelection(prefix, nme, disamb)(k) else subTerm(prefix): p => - val selRes = loweringCtx.registerTempSymbol(N, "selRes") + val selRes = loweringCtx.registerTempSymbol(N, erasedType = N, "selRes") // * We are careful to access `x.f` before `x.f$__checkNotMethod` in case `x` is, eg, `undefined` and // * the access should throw an error like `TypeError: Cannot read property 'f' of undefined`. blockBuilder @@ -1467,12 +1467,12 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) go(paramLists.reverse, bod) def setupFunctionBody(params: ParamList, bod: Term, name: Option[Str])(using LoweringCtx): Block = inScopedBlock: - val enterMsgSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogEnterMsg") - val prevIndentLvlSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogPrevIndent") - val resSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogRes") - val retMsgSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogRetMsg") - val psInspectedSyms = params.params.map(p => loweringCtx.registerTempSymbol(N, dbgNme = s"traceLogParam_${p.sym.nme}") -> p.sym) - val resInspectedSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogResInspected") + val enterMsgSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogEnterMsg") + val prevIndentLvlSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogPrevIndent") + val resSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogRes") + val retMsgSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogRetMsg") + val psInspectedSyms = params.params.map(p => loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = s"traceLogParam_${p.sym.nme}") -> p.sym) + val resInspectedSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogResInspected") val psSymArgs = psInspectedSyms.zipWithIndex.foldRight[Ls[Arg]](Arg(N, Value.Lit(Tree.StrLit(")"))) :: Nil): @@ -1480,7 +1480,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) then Arg(N, s.asSimpleRef) :: acc else Arg(N, s.asSimpleRef) :: Arg(N, Value.Lit(Tree.StrLit(", "))) :: acc - val tmp1, tmp2, tmp3 = loweringCtx.registerTempSymbol(N) + val tmp1, tmp2, tmp3 = loweringCtx.registerTempSymbol(N, erasedType = N) assignStmts(psInspectedSyms.map: (pInspectedSym, pSym) => pInspectedSym -> pureCall(inspectFn, Arg(N, pSym.asSimpleRef) :: Nil) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala index 62c13f56ea..dbb1efac51 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala @@ -54,7 +54,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n def assign(res: Result, symName: Str = "tmp")(k: Path => Block): Block = // TODO: skip assignment if res: Path? - val sym = new TempSymbol(N, symName) + val sym = new TempSymbol(N, erasedType = N, symName) Scoped(Set(sym), Assign(sym, res, k(sym.asSimpleRef))) def tuple(elems: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = @@ -374,7 +374,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n val sym = f.owner.get.asThis.selSN(genSymName) // turn into fundefn - val dSym = TermSymbol(f.dSym.k, f.dSym.owner, Tree.Ident(f.sym.nme + "_instr")) + val dSym = TermSymbol(f.dSym.k, f.dSym.owner, Tree.Ident(f.sym.nme + "_instr"), erasedType = N) val argSyms = f.params.flatMap(_.params).map(_.sym) val newBody = Scoped(Set(argSyms*), transformFunDefn(f)(using new HashMap)(Return(_))) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala index bb6ab15766..4c6866abeb 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -65,7 +65,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S override def applyBlock(b: Block): Block = b match case Return(res) if usesStack(res) => - val tmp = TempSymbol(N, "res") + val tmp = TempSymbol(N, erasedType = N, "res") super.applyResult(res): res => Scoped(Set.single(tmp), extract(res, true, Return(_), tmp, curDepth)) // Optimization to avoid generation of unnecessary variables @@ -84,7 +84,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S override def applyResult(r: Result)(k: Result => Block): Block = if usesStack(r) then - val tmp = TempSymbol(N, "res") + val tmp = TempSymbol(N, erasedType = N, "res") Scoped(Set.single(tmp), extract(r, false, k, tmp, curDepth)) else super.applyResult(r)(k) @@ -138,9 +138,9 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S var usedDepth = false lazy val curDepth = usedDepth = true - TempSymbol(None, "curDepth") + TempSymbol(None, erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "curDepth") val newBody = transform(blk, curDepth) - val resSym = TempSymbol(None, "stackDelayRes") + val resSym = TempSymbol(None, erasedType = N, "stackDelayRes") val addStackSafeEffect = blk => blockBuilder .assignFieldN(runtimePath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(increment))) .staticif(usedDepth, _.assignScoped(curDepth, stackDepthPath)) @@ -160,4 +160,4 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S def rewriteFn(defn: FunDefn) = FunDefn(defn.owner, defn.sym, defn.dSym, defn.params, rewriteBlk(defn.body, L(defn.sym)))(defn.configOverride, defn.annotations) - def transformTopLevel(b: Block) = transform(b, TempSymbol(N), true) + def transformTopLevel(b: Block) = transform(b, TempSymbol(N, erasedType = N), true) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala index da41b40027..ab8744a111 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala @@ -27,7 +27,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends for s <- syms.toList.sortBy(_.uid) do assert(!mapping.isDefinedAt(s), s"already defined: $s") val newS = s match - case tmpSym: TempSymbol => new TempSymbol(N, tmpSym.nme) + case tmpSym: TempSymbol => new TempSymbol(N, erasedType = tmpSym.erasedType, tmpSym.nme) case bms: BlockMemberSymbol => val newBms = new BlockMemberSymbol(bms.nme, Nil, bms.nameIsMeaningful) newBms.tsym = bms.tsym.map: t => @@ -35,12 +35,12 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends existingMapping.get(o) match case Some(inner: InnerSymbol) => inner case _ => o - val nt = new TermSymbol(t.k, newOwner, t.id) + val nt = new TermSymbol(t.k, newOwner, t.id, erasedType = t.erasedType) mapping(t) = nt oldSyms.add(t) nt newBms - case varSym: VarSymbol => new VarSymbol(varSym.id) + case varSym: VarSymbol => new VarSymbol(varSym.id, erasedType = varSym.erasedType) case _ => lastWords(s"unexpected symbol kind: $s") mapping(s) = newS oldSyms.add(s) @@ -85,7 +85,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends val newBms = new BlockMemberSymbol(fun.sym.nme, fun.sym.trees, fun.sym.nameIsMeaningful) val newDsym = fun.sym.tsym.map: tsym => assert(tsym.owner.isEmpty) - new TermSymbol(tsym.k, N, tsym.id) + new TermSymbol(tsym.k, N, tsym.id, erasedType = tsym.erasedType) newBms.tsym = S(newDsym.get) // Keep the definition symbol in sync with the freshly-created member symbol. // Self-recursive references use the disambiguating TermSymbol, and later passes @@ -100,7 +100,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends def handleSingleParam(p: Param) = val Param(flags, sym, sign, modulefulness) = p oldParamSyms.append(sym) - val newSym = new VarSymbol(sym.id) + val newSym = new VarSymbol(sym.id, erasedType = sym.erasedType) assert(!mapping.isDefinedAt(sym)) mapping(sym) = newSym Param(flags, newSym, sign, modulefulness) @@ -117,7 +117,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends val (tsym2, sym2) = mapping.get(sym) match case None => val newBms = new BlockMemberSymbol(sym.nme, sym.trees, sym.nameIsMeaningful) - val newTsym = new TermSymbol(tsym.k, tsym.owner, tsym.id) + val newTsym = new TermSymbol(tsym.k, tsym.owner, tsym.id, erasedType = tsym.erasedType) newBms.tsym = S(newTsym) (newTsym, newBms) case S(bms: BlockMemberSymbol) => @@ -164,7 +164,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends def freshenParamList(pl: ParamList): ParamList = def handleParam(p: Param) = - val ns = new VarSymbol(p.sym.id) + val ns = new VarSymbol(p.sym.id, erasedType = p.sym.erasedType) assert(!mapping.isDefinedAt(p.sym)) mapping(p.sym) = ns hd += p.sym @@ -231,7 +231,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends fields: Ls[TermSymbol], ownerIsym: InnerSymbol ): Ls[TermSymbol] = fields.map: ts => assert(!mapping.isDefinedAt(ts)) - val nts = new TermSymbol(ts.k, S(ownerIsym), ts.id) + val nts = new TermSymbol(ts.k, S(ownerIsym), ts.id, erasedType = ts.erasedType) mapping(ts) = nts toRemoveSymbols.head += ts nts @@ -243,7 +243,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends assert(!mapping.isDefinedAt(bms)) assert(!mapping.isDefinedAt(ts)) val nbms = new BlockMemberSymbol(bms.nme, Nil, bms.nameIsMeaningful) - val nts = new TermSymbol(ts.k, S(ownerIsym), ts.id) + val nts = new TermSymbol(ts.k, S(ownerIsym), ts.id, erasedType = ts.erasedType) nbms.tsym = S(nts) mapping(bms) = nbms mapping(ts) = nts @@ -258,7 +258,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends assert(!mapping.isDefinedAt(m.sym)) assert(!mapping.isDefinedAt(m.dSym)) val newMsym = new BlockMemberSymbol(m.sym.nme, Nil, m.sym.nameIsMeaningful) - val newDsym = new TermSymbol(m.dSym.k, S(newIsym), m.dSym.id) + val newDsym = new TermSymbol(m.dSym.k, S(newIsym), m.dSym.id, erasedType = m.dSym.erasedType) newMsym.tsym = S(newDsym) mapping(m.sym) = newMsym mapping(m.dSym) = newDsym @@ -268,7 +268,7 @@ class SymbolRefresher(existingMapping: Map[Symbol, Symbol])(using State) extends val methodParamOlds = MutSet.empty[VarSymbol] val newParams = m.params.map: pl => def handleParam(p: Param) = - val ns = new VarSymbol(p.sym.id) + val ns = new VarSymbol(p.sym.id, erasedType = p.sym.erasedType) assert(!mapping.isDefinedAt(p.sym)) mapping(p.sym) = ns methodParamOlds += p.sym diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 3bcf8a89ba..9f61aba405 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -266,7 +266,7 @@ class TailRecOpt(using State, TL, Raise): val paramSyms = if funs.length === 1 then (getParamSyms(funs.head)) else - for i <- 0 until maxParamLen yield VarSymbol(Tree.Ident("param" + i)) + for i <- 0 until maxParamLen yield VarSymbol(Tree.Ident("param" + i), erasedType = N) .toList val paramSymsArr = ArrayBuffer.from(paramSyms) // Function -> param -> param symbol in the rewritten function @@ -286,9 +286,9 @@ class TailRecOpt(using State, TL, Raise): else BlockMemberSymbol(funs.map(_.sym.nme).mkString("_"), Nil, true) val dSym = if funs.size === 1 then funs.head.dSym - else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme)) + else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme), erasedType = N) val loopSym = LabelSymbol(N, "loopLabel") - val curIdSym = VarSymbol(Tree.Ident("id")) + val curIdSym = VarSymbol(Tree.Ident("id"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int))) class FunRewriter(f: FunDefn) extends BlockTransformerShallow(SymbolSubst.Id): val params = getParamSyms(f) @@ -323,7 +323,7 @@ class TailRecOpt(using State, TL, Raise): .toSet val copiedParamSyms = copiedParams.map: x => - x -> VarSymbol(x.id) + x -> VarSymbol(x.id, erasedType = N) .toMap val subst = new SymbolSubst: @@ -358,7 +358,7 @@ class TailRecOpt(using State, TL, Raise): // We should thus assign the params to temporary symbols // if they are needed for a subsequent assignment. var assignedSyms: Map[VarSymbol, Lazy[TempSymbol]] = paramSyms.map: sym => - sym -> Lazy(TempSymbol(N, sym.nme + "_tmp")) // Use `Lazy` to avoid generating useless symbols + sym -> Lazy(TempSymbol(N, erasedType = sym.erasedType, sym.nme + "_tmp")) // Use `Lazy` to avoid generating useless symbols .toMap var requiredTmps: Set[(VarSymbol, TempSymbol)] = Set.empty @@ -422,7 +422,7 @@ class TailRecOpt(using State, TL, Raise): val paramList = ogParamList.params val restParam = ogParamList.restParam - val tupleSym = TempSymbol(N, "argList") + val tupleSym = TempSymbol(N, erasedType = N, "argList") // We can safely remove all of the symbols from this parameter list from `assignedSyms` at this stage, // because the RHS of every parameter will be computed when spreading them in the tuple, which happens @@ -438,7 +438,7 @@ class TailRecOpt(using State, TL, Raise): // If the rest param exists, append a slice val (initialBlk: (Block => Block), pathList: List[Path]) = if restParam.isDefined then - val sliceResSym = TempSymbol(N, "sliceRes") + val sliceResSym = TempSymbol(N, erasedType = N, "sliceRes") // runtime.Tuple.slice(tupleSym, paramList.length, 0) val sliceRes = Call( State.runtimeSymbol.asSimpleRef @@ -520,7 +520,7 @@ class TailRecOpt(using State, TL, Raise): // param list for the internal loop. We need a wrapper function that preserves the // original multi-param-list interface and delegates to the flattened internal loop. val loopBms = BlockMemberSymbol(bms.nme + "$tailrec", Nil, true) - val loopDSym = TermSymbol(syntax.Fun, owner, Tree.Ident(loopBms.nme)) + val loopDSym = TermSymbol(syntax.Fun, owner, Tree.Ident(loopBms.nme), erasedType = N) val internalLoopDefn = FunDefn( owner, loopBms, loopDSym, PlainParamList(params) :: Nil, diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala index e184e54240..13dbc19094 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala @@ -39,7 +39,7 @@ class WorkerWrapper body.size <= cfg.altSmallThreshold private def freshParam(param: Param, mapping: collection.mutable.Map[Symbol, Symbol]): Param = - val freshSym = new VarSymbol(param.sym.id) + val freshSym = new VarSymbol(param.sym.id, erasedType = param.sym.erasedType) mapping(param.sym) = freshSym Param(param.flags, freshSym, param.sign, param.modulefulness).withSignTypeOf(param) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala index 95a8d49cd6..faab92e516 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala @@ -117,7 +117,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val name = path.mkFunName + s"$$${f.nme}" f -> ( new BlockMemberSymbol(name, Nil, true), - new TermSymbol(Fun, N, Tree.Ident(name))) + new TermSymbol(Fun, N, Tree.Ident(name), erasedType = N)) .toMap) end mkNewPolyFnSyms @@ -153,7 +153,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val clsNme = selInfo.selectsFrom.ctorClsName fieldSym.getOrElseUpdate( selInfo.field, - new VarSymbol(Tree.Ident(s"${clsNme}_${selInfo.field.fieldName}"))) + new VarSymbol(Tree.Ident(s"${clsNme}_${selInfo.field.fieldName}"), erasedType = N)) ) // ctor dest branch function computations @@ -183,7 +183,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val owner = pre.res.matchScrutToCtxOfMatch(dest._1).collectFirst: case pre.InCtx.Cls(cls) => cls.isym new BlockMemberSymbol(branchFnNme, Nil, true) - -> new TermSymbol(Fun, owner, Tree.Ident(branchFnNme)) + -> new TermSymbol(Fun, owner, Tree.Ident(branchFnNme), erasedType = N) ) // compute the function parameters corresponding to ctor fields of branch funs branchFunParamFieldSyms.getOrElseUpdate( @@ -198,7 +198,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): completeArgs.map: selField => selsInfos.get(selField) match case Some(selId) => branchSelSyms(selId) - case None => VarSymbol(Tree.Ident(s"_${selField.fieldName}")) + case None => VarSymbol(Tree.Ident(s"_${selField.fieldName}"), erasedType = N) ) val (parents, _) = getParentLabelOrMatchesAndRestBefore(dest.exprId) @@ -219,7 +219,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): pre.res.matchScrutToCtxOfMatch(dtorId).collectFirst: case pre.InCtx.Cls(cls) => cls.isym new BlockMemberSymbol(restFunName, Nil, true) - -> new TermSymbol(Fun, owner, Tree.Ident(restFunName)) + -> new TermSymbol(Fun, owner, Tree.Ident(restFunName), erasedType = N) ) val (ps, restBeforeParent) = getParentLabelOrMatchesAndRestBefore(matchOrLabelId) restOriginalBodiesAndParentRest.getOrElseUpdate( @@ -405,7 +405,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val ctorInfo = solver.fusingCtorInfo(ctorDtorId) val clsNme = ctorInfo.ctor.ctorClsName ctorInfo.args.unzip._1.map: f => - new TempSymbol(N, s"${clsNme}_${f.fieldName}") + new TempSymbol(N, erasedType = N, s"${clsNme}_${f.fieldName}") end mkCtorFieldSyms solver.finalCtorDests.get(ctor.uid.concreteId) match @@ -437,7 +437,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): case ctor@CtorCall(_, args) if solver.finalCtorDests.isDefinedAt(ctor.uid.concreteId) => assert(args.isEmpty) val callBranchFun = mkCall(branchFunSyms(ctorWhichBranch(ctor.uid.concreteId)), Nil) - val lambdaSym = new TempSymbol(N, "deforest$lam") + val lambdaSym = new TempSymbol(N, erasedType = N, "deforest$lam") Scoped( Set.single(lambdaSym), Assign( @@ -512,7 +512,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): ParamList( pl.flags, pl.params.map: p => - val newSym = new VarSymbol(Tree.Ident(p.sym.name)) + val newSym = new VarSymbol(Tree.Ident(p.sym.name), erasedType = p.sym.erasedType) refreshParamMap(p.sym) = newSym Param(p.flags, newSym, p.sign, p.modulefulness), pl.restParam) @@ -531,7 +531,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val actualBody = Begin( new Rewriter(instId).applyBlock(ogBody), mkReturnCall(restFunSym, restFunArgs)) - val refreshedFvSymbols = dtorBranchFnFvs(branchId._1).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"))) + val refreshedFvSymbols = dtorBranchFnFvs(branchId._1).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"), erasedType = N)) val bodyWithCorrectSymbols = new RefreshSymbol(refreshedFvSymbols.toMap).applyBlock(actualBody) FunDefn(tSym.owner, bms, tSym, branchFunParamFieldSyms(branchId).asParamList :: refreshedFvSymbols.unzip._2.asParamList :: Nil, @@ -554,7 +554,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): mkReturnCall(parentFunSym, parentFunFvs)) case None => Begin(transformedOgBody, Return(Value.Lit(Tree.UnitLit(true)))) - val refreshedFvSymbols = restFnFvs(restFunId).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"))) + val refreshedFvSymbols = restFnFvs(restFunId).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"), erasedType = N)) val bodyWithCorrectSymbols = new RefreshSymbol(refreshedFvSymbols.toMap).applyBlock(actualBody) FunDefn(tsym.owner, bms, tsym, refreshedFvSymbols.unzip._2.asParamList :: Nil, bodyWithCorrectSymbols)(N, annotations = PrivateModifier :: Nil) end newRestFuns diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index 0cf4b49d72..bd6b9b887e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -76,7 +76,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case _ => false private def getPrivateAccessorSymbol(ts: semantics.TermSymbol): semantics.TempSymbol = - privateAccessorSymbols.getOrElseUpdate(ts, semantics.TempSymbol(N, s"${ts.name}$$accessorSymbol")) + privateAccessorSymbols.getOrElseUpdate(ts, semantics.TempSymbol(N, erasedType = N, s"${ts.name}$$accessorSymbol")) private def selectPrivateField(ts: semantics.TermSymbol, loc: Opt[Loc])(using Raise, Scope): Opt[Document] = ts.owner.collect: @@ -459,7 +459,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: pubFlds.collect: case (_, sym) if sym.k is MutVal => sym -> TermSymbol( - syntax.LetBind, S(isym), Tree.Ident(sym.nme)) + syntax.LetBind, S(isym), Tree.Ident(sym.nme), erasedType = sym.erasedType) val allPrivFlds = privFlds ++ mutPubFields.map(_._2) val privDecls = allPrivFlds.map: fld => val nme = isym.privatesScope.allocateOrGetName(fld) @@ -1017,7 +1017,7 @@ trait JSBuilderArgNumSanityChecks(using TL, Config, Elaborator.State) override def checkSelections: Bool = instrument override def freezeDefinitions: Bool = instrument - val functionParamVarargSymbol = semantics.TempSymbol(N, "args") + val functionParamVarargSymbol = semantics.TempSymbol(N, erasedType = N, "args") override def setupFunction(name: Option[Str], params: ParamList, body: Block, isLambda: Bool)(using Raise, Scope): (Document, Document) = // * We used to instrument `fun f(x, y) = x + y` into something like diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 70bb4401d2..1ecda4531f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -104,17 +104,17 @@ final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): private def symMap(s: Local)(using ctx: Ctx)(using Raise, Scope) = ctx.findName(s) - private def newTemp = TempSymbol(N, "x") - private def newNamedTemp(name: Str) = TempSymbol(N, name) + private def newTemp = TempSymbol(N, erasedType = N, "x") + private def newNamedTemp(name: Str) = TempSymbol(N, erasedType = N, name) private def newNamedBlockMem(name: Str) = BlockMemberSymbol(name, Nil) - private def newNamed(name: Str) = VarSymbol(Tree.Ident(name)) + private def newNamed(name: Str) = VarSymbol(Tree.Ident(name), erasedType = N) private def newClassSym(name: Str) = ClassSymbol(Tree.TypeDef(hkmc2.syntax.Cls, Tree.Empty(), N), Tree.Ident(name)) private def newTupleSym(len: Int) = ClassSymbol(Tree.TypeDef(hkmc2.syntax.Cls, Tree.Empty(), N), Tree.Ident(s"Tuple$len")) - private def newVarSym(name: Str) = VarSymbol(Tree.Ident(name)) + private def newVarSym(name: Str) = VarSymbol(Tree.Ident(name), erasedType = N) private def newFunSym(name: Str) = BlockMemberSymbol(name, Nil) - private def newBuiltinSym(name: Str) = BuiltinSymbol(name, false, false, false, false) + private def newBuiltinSym(name: Str) = BuiltinSymbol(name, false, false, false, false, erasedType = N) private def builtinField(n: Int)(using Ctx) = summon[Ctx].builtinSym.fieldSym.getOrElseUpdate(n, newVarSym(s"field$n")) private def builtinApply(n: Int)(using Ctx) = summon[Ctx].builtinSym.applySym.getOrElseUpdate(n, newFunSym(s"apply$n")) private def builtinTuple(n: Int)(using Ctx) = summon[Ctx].builtinSym.tupleSym.getOrElseUpdate(n, newTupleSym(n)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala index a34296df20..ee4da70dbd 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala @@ -694,7 +694,7 @@ class Ctx(using State) extends ToWat: val id = SymIdx(name) memories = memories + (id -> - Import(module, name, ExternType.Mem(MemType(Limits(minPages)), sym = TempSymbol(N, name)))) + Import(module, name, ExternType.Mem(MemType(Limits(minPages)), sym = TempSymbol(N, erasedType = N, name)))) cachedMemoryImport(key) = SymIdx(name) end match end ensureMemoryImport diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 69e93b7f3f..8a87f4a744 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -52,13 +52,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private val baseObjectSym: BlockMemberSymbol = BlockMemberSymbol("Object", Nil) /** Synthetic field symbol for the object-header pointer to a class's shared RTTI object. */ - private val typeInfoFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$typeinfo")) + private val typeInfoFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$typeinfo"), erasedType = N) /** Synthetic field symbol for the runtime class tag stored in RTTI. */ - private val tagFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$tag")) + private val tagFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$tag"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int))) /** Synthetic field symbol for the direct-parent RTTI reference used by runtime subtype checks. */ - private val parentFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$parent")) + private val parentFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$parent"), erasedType = N) private case class StringLitInfo(offset: Int, byteLen: Int, watBytes: Str) private val stringLits: LinkedHashMap[Str, StringLitInfo] = LinkedHashMap.empty @@ -130,9 +130,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: scrutTypeInfo: Expr, targetTypeInfo: Expr, )(using Ctx, FunctionCtx, Raise): Expr = - val currentTmp = mkTempLocal("currentTypeInfo") - val targetTmp = mkTempLocal("targetTypeInfo") - val resultTmp = mkTempLocal("typeInfoMatch") + val currentTmp = mkTempLocal("currentTypeInfo", erasedType = N) + val targetTmp = mkTempLocal("targetTypeInfo", erasedType = N) + val resultTmp = mkTempLocal("typeInfoMatch", erasedType = S(ErasedType.Primitive(PrimitiveType.Bool))) funcCtx.withLabel(LabelSymbol(N, "typeInfo"), hasContinueLabel = true): case LabelTarget(breakLabel, S(continueLabel)) => blockInstr( @@ -473,7 +473,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: params: Seq[Local -> SymIdx], )(using Ctx, Raise): TypeIdx = ctx.addType(TypeInfo( - sym = TempSymbol(N, defn.sym.nme), + sym = TempSymbol(N, erasedType = N, defn.sym.nme), FunctionType( params = params.map(p => WasmParam(p._2, RefType.anyref)), results = Seq(Result(RefType.anyref)), @@ -494,7 +494,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private def virtualMethodFuncType(arity: Int)(using Ctx, Raise): TypeIdx = ctx.getOrCreateWasmIntrinsicType(WasmIntrinsicType.VirtualMethod(arity)): ctx.addType(TypeInfo( - sym = TempSymbol(N, s"virtual$arity"), + sym = TempSymbol(N, erasedType = N, s"virtual$arity"), compType = virtualMethodSignature(arity), objectTag = N, )) @@ -651,7 +651,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case func: SessionFunc => // If the function symbol comes from a class or module, generate a TempSymbol to avoid symbol collision with // the class/module itself - val funcTySym = TempSymbol(N, func.sym.nme) + val funcTySym = TempSymbol(N, erasedType = N, func.sym.nme) val typeIdx = ctx.addType(TypeInfo(sym = funcTySym, wrapId = func.wrapId, compType = func.funcType, objectTag = N)) ctx.addFunctionImport(WasmImport( @@ -687,7 +687,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: typeInfoTypeIdxs(cls.sym) = typeInfoTypeIdx val globalExtern = ExternType.Global( GlobalType(RefType(typeInfoTypeIdx, nullable = false), mutable = false), - TempSymbol(N, cls.sym.nme), + TempSymbol(N, erasedType = N, cls.sym.nme), wrapId = N -> S("typeinfo"), ) val globalIdx = ctx.addGlobalImport(WasmImport( @@ -717,7 +717,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val newSlotFields = currentVirtualMethods.zipWithIndex.drop(parentVirtualMethodCount).map: (methodSym, slot) => val methodDefn = defn.methods.find(_.sym == methodSym).get val arity = 1 + methodDefn.params.headOption.fold(0)(_.params.size) - val fieldSym = TermSymbol(syntax.MutVal, owner = N, Ident(s"slot$slot")) + val fieldSym = TermSymbol(syntax.MutVal, owner = N, Ident(s"slot$slot"), erasedType = N) fieldSym -> Field( RefType(virtualMethodFuncType(arity), nullable = true), mutable = true, @@ -725,7 +725,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ) val typeInfoType = ctx.addType(TypeInfo( - sym = TempSymbol(N, defn.sym.nme), + sym = TempSymbol(N, erasedType = N, defn.sym.nme), compType = StructType(fields = inheritedFields ++ newSlotFields, parents = Seq(parentTypeInfoIdx)), objectTag = N, wrapId = N -> S("typeinfo"), @@ -796,7 +796,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Gets (and caches) the exception tag used for MLX `throw`. */ private def exnTagIdx(using Ctx, Raise): TagIdx = - val sym = TempSymbol(N, "mlx_exn") + val sym = TempSymbol(N, erasedType = N, "mlx_exn") ctx.getOrCreateWasmIntrinsicTag( "mlx_exn", ctx.addTag(TagInfo( @@ -846,7 +846,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: module = ExternIntrinsics.SystemModule, name = ExternIntrinsics.StringFromUtf16ImportName, ): - val importTySym = TempSymbol(N, ExternIntrinsics.StringFromUtf16ImportName) + val importTySym = TempSymbol(N, erasedType = N, ExternIntrinsics.StringFromUtf16ImportName) val importTy = ctx.addType(TypeInfo( sym = importTySym, compType = FunctionType( @@ -880,8 +880,8 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Allocates a fresh temp local (typed `anyref`) and returns its `LocalIdx`. */ - private def mkTempLocal(base: Str)(using Ctx, FunctionCtx, Raise): LocalIdx = - funcCtx.addLocal(TempSymbol(N, base)) + private def mkTempLocal(base: Str, erasedType: Opt[ErasedType])(using Ctx, FunctionCtx, Raise): LocalIdx = + funcCtx.addLocal(TempSymbol(N, erasedType, base)) /** Binds constructor self (`thisSym`) to the Wasm local name `this` in the current function context. */ @@ -1005,7 +1005,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val elemType = RefType.anyref val mutArrayType = tupleArrayType(true) val immArrayType = tupleArrayType(false) - val tupleTmp = mkTempLocal("tuple") + val tupleTmp = mkTempLocal("tuple", erasedType = N) val tupleIsMutable = ref.test(local.tee(tupleTmp, tupleExpr), RefType(mutArrayType, nullable = true)) val tupleValue = local.get(tupleTmp, RefType.anyref) val mutableBranch = @@ -1050,7 +1050,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: extraInfo = S(errExtra), ) - val idxTmp = mkTempLocal("idx") + val idxTmp = mkTempLocal("idx", erasedType = S(ErasedType.Primitive(PrimitiveType.Int31))) tupleRef => val storeIdx = local.set(idxTmp, ref.i31(idxI32)) @@ -1158,7 +1158,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ctx.getVirtualTable(ownerCls).flatMap(_.virtualMethodSlots.get(methodSym)) match case S(slot) => val ownerTypeInfoIdx = typeInfoTypeIdxs(ownerCls) - val receiverTmp = mkTempLocal("receiver") + val receiverTmp = mkTempLocal("receiver", erasedType = N) val receiverExpr = local.set(receiverTmp, result(qual)) val receiverRef = local.get(receiverTmp, RefType.anyref) val ownerTypeInfoRef = ref.cast( @@ -1500,7 +1500,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ctx.addFunctionImport(WasmImport( ExternIntrinsics.SystemModule, name, - ExternType.Func(TypeUse(typeIdx), TempSymbol(N, name), wrapId = N -> N), + ExternType.Func(TypeUse(typeIdx), TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), name), wrapId = N -> N), )) /** Creates the intrinsic definition for `name`. @@ -1515,7 +1515,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private def declareIntrinsicType(name: Str)(using Ctx, Raise): TypeIdx = ctx.addType(TypeInfo( - sym = TempSymbol(N, name), + sym = TempSymbol(N, erasedType = N, name), compType = FunctionType( params = intrinsicParamSuffixes(name).map(nme => WasmParam(SymIdx(nme), RefType.anyref)), results = Seq(Result(RefType.anyref)), @@ -1558,7 +1558,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: )(using Ctx, Raise): FuncIdx = val funcTy = declareIntrinsicType(name) val funcInfo = FuncInfo( - sym = TempSymbol(N, name), + sym = TempSymbol(N, erasedType = N, name), typeUse = TypeUse(funcTy), params = params, locals = Seq.empty, @@ -1608,9 +1608,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Creates parameters for an intrinsic. */ + // TODO(Derppening): WTF? Remove `name` and add erasedType to `params` private def mkIntrinsicParams(name: Str, suffixes: Seq[Str]): Seq[TempSymbol -> SymIdx] = suffixes.map: suffix => - val sym = TempSymbol(N, suffix) + val sym = TempSymbol(N, erasedType = N, suffix) sym -> SymIdx(suffix) /** Loads the local `name` as an `anyref`. @@ -1804,7 +1805,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: if sym.nameIsMeaningful then val funcTy = ctx.addType( TypeInfo( - sym = TempSymbol(N, sym.nme), + sym = TempSymbol(N, erasedType = N, sym.nme), compType = FunctionType( params = fnCtx.params.map(p => WasmParam(p._2, RefType.anyref)), results = Seq.fill(bodyWat.resultTypes.length)(Result(RefType.anyref)), @@ -2100,11 +2101,11 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case Match(scrut, arms, dflt, rst) => val tailMode = rst.isInstanceOf[End] val matchResLocal = - if tailMode then S(mkTempLocal("matchRes")) + if tailMode then S(mkTempLocal("matchRes", erasedType = N)) else N val scrutLocalResult = scrut match case _: (Value.RefLike | Value.Lit) => N - case _ => S(mkTempLocal("scrut")) + case _ => S(mkTempLocal("scrut", erasedType = N)) val scrutInitExpr = scrutLocalResult.map: scrutLocal => local.set(scrutLocal, result(scrut)) @@ -2386,7 +2387,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val entrySym = BlockMemberSymbol("entry", Nil) val entryFnTy = ctx.addType(TypeInfo( - sym = TempSymbol(N, entrySym.nme), + sym = TempSymbol(N, erasedType = N, entrySym.nme), FunctionType(params = Seq.empty, results = Seq(Result(RefType.anyref))), objectTag = N, )) @@ -2407,13 +2408,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: offset = i32.const(lit.offset), bytes = Seq(lit.watBytes), memuse = N, - sym = TempSymbol(N, s.take(WatBuilder.StringConstantIdentMaxLength)), + sym = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), s.take(WatBuilder.StringConstantIdentMaxLength)), )) val initActions = ctx.getSingletonInitActions if initActions.nonEmpty then val initTy = ctx.addType(TypeInfo( - sym = TempSymbol(N, "start"), + sym = TempSymbol(N, erasedType = N, "start"), compType = FunctionType(params = Seq.empty, results = Seq.empty), objectTag = N, )) @@ -2423,7 +2424,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: resultTypes = Seq.empty, ) val initFn = ctx.addFunc(FuncInfo( - sym = TempSymbol(N, "start"), + sym = TempSymbol(N, erasedType = N, "start"), typeUse = TypeUse(initTy), params = Seq.empty, resultTypes = Seq.empty, diff --git a/hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala b/hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala index 1086f70968..5dd3dc2d13 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala @@ -141,7 +141,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): case _ => error(msg"Effect cannot be polymorphic." -> ty.toLoc :: Nil) }).getOrElse(Bot)) case f @ Term.Forall(tvs, outer, body) => - val outVar = freshOuter(outer.getOrElse(new TempSymbol(S(f), "outer")))(using ctx) + val outVar = freshOuter(outer.getOrElse(new TempSymbol(S(f), erasedType = N, "outer")))(using ctx) val nestCtx = ctx.nestWithOuter(outVar) outer.foreach(sym => nestCtx += sym -> outVar) given InvalCtx = nestCtx @@ -203,7 +203,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): typeAndSubstType(ty, pol = true)(using Map.empty) private def instantiate(ty: PolyType)(using ctx: InvalCtx): GeneralType = - ty.instantiate(infVarState.nextUid, freshEnv(new TempSymbol(N, "env")), ctx.lvl)(tl) + ty.instantiate(infVarState.nextUid, freshEnv(new TempSymbol(N, erasedType = N, "env")), ctx.lvl)(tl) private def extrude(ty: GeneralType)(using ctx: InvalCtx, pol: Bool, cctx: CCtx): GeneralType = ty match case ty: Type => solver.extrude(ty)(using ctx.lvl, pol, HashMap.empty) @@ -239,7 +239,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): nestCtx &= (sym, tv, sk) (tv, sk) val (bodyTy, ctxTy, eff) = typeCode(body) - val res = freshVar(new TempSymbol(S(f), "ctx"))(using ctx) + val res = freshVar(new TempSymbol(S(f), erasedType = N, "ctx"))(using ctx) constrain(ctxTy, bds.foldLeft[Type](res)((res, bd) => res | bd._2)) (FunType(bds.map(_._1), bodyTy, Bot), res, eff) case app @ Term.App(lhs, Term.Tup(rhs)) => @@ -248,7 +248,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): case (res, p: Fld) => val (ty, ctx, eff) = typeCode(p.term) (ty :: res._1, res._2 | ctx, res._3 | eff) - val resTy = freshVar(new TempSymbol(S(app), "app")) + val resTy = freshVar(new TempSymbol(S(app), erasedType = N, "app")) constrain(lhsTy, FunType(rhsTy.reverse, resTy, Bot)) // TODO: right (resTy, lhsCtx | rhsCtx, lhsEff | rhsEff) case sel @ Term.SynthSel(Term.Ref(_: TopLevelSymbol), _) if sel.symbol.isDefined => @@ -256,8 +256,8 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): (tryMkMono(opTy, sel), Bot, eff) case unq @ Term.Unquoted(body) => val (ty, eff) = typeCheck(body) - val tv = freshVar(new TempSymbol(S(unq), "cde")) - val cr = freshVar(new TempSymbol(S(unq), "ctx")) + val tv = freshVar(new TempSymbol(S(unq), erasedType = N, "cde")) + val cr = freshVar(new TempSymbol(S(unq), erasedType = N, "ctx")) constrain(tryMkMono(ty, body), InvalCtx.codeTy(tv, cr)) (tv, cr, eff) case blk @ Term.Blk(LetDecl(sym, _) :: DefineVar(sym2, rhs) :: Nil, body) @@ -268,7 +268,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): val sk = freshSkolem(sym) nestCtx &= (sym, rhsTy, sk) val (bodyTy, bodyCtx, bodyEff) = typeCode(body) - val res = freshVar(new TempSymbol(S(blk), "ctx"))(using ctx) + val res = freshVar(new TempSymbol(S(blk), erasedType = N, "ctx"))(using ctx) constrain(bodyCtx, sk | res) (bodyTy, rhsCtx | res, rhsEff | bodyEff) case Term.IfLike(_, IfLikeForm.ReturningIf, SimpleSplit.IfThenElse(cond, cons, alts)) => @@ -288,7 +288,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): ascribe(lam, sigTy) () case N => - val outer = freshOuter(new TempSymbol(S(lam), "outer"))(using ctx) + val outer = freshOuter(new TempSymbol(S(lam), erasedType = N, "outer"))(using ctx) given InvalCtx = ctx.nestWithOuter(outer) val funTyV = freshVar(sym) ctx += sym -> funTyV // for recursive functions @@ -409,7 +409,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): case S(sym) => val (clsTy, tv, emptyTy) = sym.defn.map(sym -> _) match case S((sym, cls)) => - (ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(sym))), (freshVar(new TempSymbol(S(scrutinee), "scrut"))), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty))) + (ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(sym))), (freshVar(new TempSymbol(S(scrutinee), erasedType = N, "scrut"))), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty))) case _ => error(msg"Cannot match ${scrutinee.toString} as ${sym.toString}" -> split.toLoc :: Nil) (Bot, Bot, Bot) @@ -532,8 +532,8 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): val (ty, eff) = typeCheck(f.term) Left(ty) :: Right(eff) :: Nil .partitionMap(x => x) - val effVar = freshVar(new TempSymbol(S(t), "eff")) - val retVar = freshVar(new TempSymbol(S(t), "app")) + val effVar = freshVar(new TempSymbol(S(t), erasedType = N, "eff")) + val retVar = freshVar(new TempSymbol(S(t), erasedType = N, "app")) constrain(tryMkMono(funTy, t), FunType(argTy.map((tryMkMono(_, t))), retVar, effVar)) (retVar, argEff.foldLeft[Type](effVar | lhsEff)((res, e) => res | e)) @@ -736,25 +736,25 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): given InvalCtx = nestCtx nestCtx += sym -> InvalCtx.regionTy(sk) val (res, eff) = typeCheck(body) - val tv = freshVar(new TempSymbol(S(reg), "eff"))(using ctx) + val tv = freshVar(new TempSymbol(S(reg), erasedType = N, "eff"))(using ctx) constrain(eff, tv | sk) (extrude(res)(using ctx, true), tv) case Term.RegRef(reg, value) => val (regTy, regEff) = typeCheck(reg) val (valTy, valEff) = typeCheck(value) - val sk = freshVar(new TempSymbol(S(reg), "reg")) + val sk = freshVar(new TempSymbol(S(reg), erasedType = N, "reg")) constrain(tryMkMono(regTy, reg), InvalCtx.regionTy(sk)) (InvalCtx.refTy(tryMkMono(valTy, value), sk), sk | (regEff | valEff)) case Term.SetRef(lhs, rhs) => val (lhsTy, lhsEff) = typeCheck(lhs) val (rhsTy, rhsEff) = typeCheck(rhs) - val sk = freshVar(new TempSymbol(S(lhs), "reg")) + val sk = freshVar(new TempSymbol(S(lhs), erasedType = N, "reg")) constrain(tryMkMono(lhsTy, lhs), InvalCtx.refTy(tryMkMono(rhsTy, rhs), sk)) (tryMkMono(rhsTy, rhs), sk | (lhsEff | rhsEff)) case Term.Deref(ref) => val (refTy, refEff) = typeCheck(ref) - val sk = freshVar(new TempSymbol(S(ref), "reg")) - val ctnt = freshVar(new TempSymbol(S(ref), "ref")) + val sk = freshVar(new TempSymbol(S(ref), erasedType = N, "reg")) + val ctnt = freshVar(new TempSymbol(S(ref), erasedType = N, "ref")) constrain(tryMkMono(refTy, ref), InvalCtx.refTy(ctnt, sk)) (ctnt, sk | refEff) case Term.Quoted(body) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index a98bbc152b..c4aeaa4fcf 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -17,6 +17,7 @@ import hkmc2.Message.MessageContext import Keyword.{`let`, `set`} import hkmc2.utils.Scope +import codegen.ErasedType object Elaborator: @@ -333,15 +334,15 @@ object Elaborator: val tupleSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Mod), Ident("Tuple")) val strSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Mod), Ident("Str")) // In JavaScript, `import` can be used for getting current file path, as `import.meta` - val importSymbol = new VarSymbol(Ident("import")) + val importSymbol = new VarSymbol(Ident("import"), erasedType = N) val noSymbol = NoSymbol() - val runtimeSymbol = TempSymbol(N, "runtime") - val definitionMetadataSymbol = TempSymbol(N, "definitionMetadata") - val prettyPrintSymbol = TempSymbol(N, "prettyPrint") - val termSymbol = TempSymbol(N, "Term") - val blockSymbol = TempSymbol(N, "Block") - val optionSymbol = TempSymbol(N, "option") - val wasmSymbol = TempSymbol(N, "wasm") + val runtimeSymbol = TempSymbol(N, erasedType = N, "runtime") + val definitionMetadataSymbol = TempSymbol(N, erasedType = N, "definitionMetadata") + val prettyPrintSymbol = TempSymbol(N, erasedType = N, "prettyPrint") + val termSymbol = TempSymbol(N, erasedType = N, "Term") + val blockSymbol = TempSymbol(N, erasedType = N, "Block") + val optionSymbol = TempSymbol(N, erasedType = N, "option") + val wasmSymbol = TempSymbol(N, erasedType = N, "wasm") val nonLocalRetHandlerTrm = val id = new Ident("NonLocalReturn") val sym = ClassSymbol(DummyTypeDef(syntax.Cls), id) @@ -352,7 +353,7 @@ object Elaborator: val nonLocalRet = val id = new Ident("ret") BlockMemberSymbol(id.name, Nil, true) - val unreachableSymbol = TermSymbol(syntax.ImmutVal, N, new Ident("unreachable")) + val unreachableSymbol = TermSymbol(syntax.ImmutVal, N, new Ident("unreachable"), erasedType = N) val tupleGetSymbol = createFunSymbolInMod("get", "xs" :: "i" :: Nil, tupleSymbol) val tupleSliceSymbol = createFunSymbolInMod("slice", "xs" :: "i" :: "j" :: Nil, tupleSymbol) val tupleLazySliceSymbol = createFunSymbolInMod("lazySlice", "xs" :: "i" :: "j" :: Nil, tupleSymbol) @@ -364,11 +365,11 @@ object Elaborator: val id = new Ident("MatchSuccess") val td = TypeDef(syntax.Cls, App(id, Tup(Ident("output") :: Ident("bindings") :: Nil)), N) val cs = ClassSymbol(td, id) - val ts = TermSymbol(syntax.Fun, N, id) + val ts = TermSymbol(syntax.Fun, N, id, erasedType = N) val flag = FldFlags.empty.copy(isVal = true) val ps = PlainParamList( - Param(flag, VarSymbol(Ident("output")), N, Modulefulness(N)(false)) :: - Param(flag, VarSymbol(Ident("bindings")), N, Modulefulness(N)(false)) :: + Param(flag, VarSymbol(Ident("output"), erasedType = N), N, Modulefulness(N)(false)) :: + Param(flag, VarSymbol(Ident("bindings"), erasedType = N), N, Modulefulness(N)(false)) :: Nil) val ctsym = ClassCtorSymbol(Fun, S(cs), cs.id) cs.defn = S(ClassDef.Parameterized(N, syntax.Cls, cs, BlockMemberSymbol(cs.name, Nil), S(ctsym), @@ -378,9 +379,9 @@ object Elaborator: val id = new Ident("MatchFailure") val td = DummyTypeDef(syntax.Cls) val cs = ClassSymbol(td, id) - val ts = TermSymbol(syntax.Fun, N, id) + val ts = TermSymbol(syntax.Fun, N, id, erasedType = N) val flag = FldFlags.empty.copy(isVal = true) - val ps = PlainParamList(Param(flag, VarSymbol(Ident("errors")), N, Modulefulness(N)(false)) :: Nil) + val ps = PlainParamList(Param(flag, VarSymbol(Ident("errors"), erasedType = N), N, Modulefulness(N)(false)) :: Nil) val ctsym = ClassCtorSymbol(Fun, S(cs), cs.id) cs.defn = S(ClassDef.Parameterized(N, syntax.Cls, cs, BlockMemberSymbol(cs.name, td :: Nil), S(ctsym), Nil, ps, Nil, N, ObjBody(Blk(Nil, Term.Lit(UnitLit(false)))), N, Nil)) @@ -391,7 +392,8 @@ object Elaborator: binary = binaryOps(op), unary = unaryOps(op), nullary = false, - functionLike = anyOps(op)) + functionLike = anyOps(op), + erasedType = N) .toMap baseBuiltins ++ aliasOps.map: case (alias, base) => alias -> baseBuiltins(base) @@ -409,9 +411,9 @@ object Elaborator: // ^ we do not display the uid by default to avoid polluting diff-test outputs // Create a term symbol for a function defined in the given module private def createFunSymbolInMod(name: Str, paramNames: List[Str], mod: ModuleOrObjectSymbol) = - val sym = TermSymbol(syntax.Fun, N, Ident(name)) + val sym = TermSymbol(syntax.Fun, N, Ident(name), erasedType = N) val bsym = BlockMemberSymbol(name, Nil, true) - val ps = PlainParamList(paramNames.map(s => Param.simple(VarSymbol(Ident(s))))) + val ps = PlainParamList(paramNames.map(s => Param.simple(VarSymbol(Ident(s), erasedType = N)))) sym.defn = S(TermDefinition(syntax.Fun, bsym, sym, ps :: Nil, N, N, N, TermDefFlags(true), Modulefulness(S(mod))(false), Nil, N)) sym @@ -532,10 +534,10 @@ extends Importer with ucs.SplitElaborator: )(using State): Term = val clsSym = ClassSymbol(DummyTypeDef(Cls), Ident(effectClassName)) val htds = methods.map: spec => - val valueSym = spec.valueParamName.map(nme => VarSymbol(Ident(nme))) - val resumeSym = VarSymbol(Ident("resume")) + val valueSym = spec.valueParamName.map(nme => VarSymbol(Ident(nme), erasedType = N)) + val resumeSym = VarSymbol(Ident("resume"), erasedType = N) val mtdSym = BlockMemberSymbol(spec.methodName, Nil, true) - val tsym = TermSymbol(Fun, N, Ident(spec.methodName)) + val tsym = TermSymbol(Fun, N, Ident(spec.methodName), erasedType = N) val td = TermDefinition( Fun, mtdSym, @@ -670,7 +672,7 @@ extends Importer with ucs.SplitElaborator: Term.Error else val lt = subterm(lhs) - val sym = TempSymbol(S(lt), "old") + val sym = TempSymbol(S(lt), erasedType = N, "old") Blk( LetDecl(sym, Nil) :: DefineVar(sym, lt) :: Nil, Term.Try(Blk( Term.Assgn(lt, subterm(rhs)) :: Nil, @@ -748,20 +750,21 @@ extends Importer with ucs.SplitElaborator: })(N).withLocOf(tree) case InfixApp(TyTup(tvs), Keywrd(Keyword.`->`), body) => val boundVars = mutable.HashMap.empty[Str, VarSymbol] - def genSym(id: Tree.Ident) = - val sym = VarSymbol(id) + def genSym(id: Tree.Ident, erasedType: Opt[ErasedType]) = + val sym = VarSymbol(id, erasedType) sym.decl = S(TyParam(FldFlags.empty, N, sym)) // TODO vce boundVars += id.name -> sym sym val syms = (tvs.collect: - case id: Tree.Ident => (genSym(id), N, N) - case InfixApp(id: Tree.Ident, Keywrd(Keyword.`extends`), ub) => (genSym(id), S(ub), N) - case InfixApp(id: Tree.Ident, Keywrd(Keyword.`restricts`), lb) => (genSym(id), N, S(lb)) - case InfixApp(InfixApp(id: Tree.Ident, Keywrd(Keyword.`extends`), ub), Keywrd(Keyword.`restricts`), lb) => (genSym(id), S(ub), S(lb)) + case id: Tree.Ident => (genSym(id, erasedType = N), N, N) + case InfixApp(id: Tree.Ident, Keywrd(Keyword.`extends`), ub) => (genSym(id, erasedType = N), S(ub), N) + case InfixApp(id: Tree.Ident, Keywrd(Keyword.`restricts`), lb) => (genSym(id, erasedType = N), N, S(lb)) + case InfixApp(InfixApp(id: Tree.Ident, Keywrd(Keyword.`extends`), ub), Keywrd(Keyword.`restricts`), lb) => + (genSym(id, erasedType = N), S(ub), S(lb)) ) val outer = (tvs.collect: - case Outer(S(name: Tree.Ident)) => genSym(name) - case Outer(N) => genSym(Tree.Ident("outer")) + case Outer(S(name: Tree.Ident)) => genSym(name, erasedType = N) + case Outer(N) => genSym(Tree.Ident("outer"), erasedType = N) ) match case ot :: Nil => S(ot) case _ :: rest => @@ -943,8 +946,8 @@ extends Importer with ucs.SplitElaborator: msg" add a space before ‹identifier› to make it an operator application." -> N :: Nil N - val self = VarSymbol(Ident("self")) - val args = VarSymbol(Ident("args")) + val self = VarSymbol(Ident("self"), erasedType = N) + val args = VarSymbol(Ident("args"), erasedType = N) val ps = ParamList(ParamListFlags.empty, Param(FldFlags.empty, self, N, Modulefulness.none) :: Nil, S: @@ -1017,7 +1020,7 @@ extends Importer with ucs.SplitElaborator: case Quoted(body) => Term.Quoted(subterm(body)) case Unquoted(body) => Term.Unquoted(subterm(body)) case tree @ Case(kw, _) => - val scrut = VarSymbol(Ident("caseScrut")) + val scrut = VarSymbol(Ident("caseScrut"), erasedType = N) val body = caseSplit(scrut, tree) val params = Param(FldFlags.empty, scrut, N, Modulefulness.none) :: Nil Term.Lam(PlainParamList(params), body).mkLocWith(kw) @@ -1046,10 +1049,10 @@ extends Importer with ucs.SplitElaborator: Term.Throw(subterm(body)).mkLocWith(kw) case PrefixApp(kw @ Keywrd(Keyword.`do`), InfixApp(labelId: Ident, Keywrd(Keyword.`:`), body)) => val labelSym = new LabelSymbol(N, labelId.name) - val resultSym = new TempSymbol(N, s"${labelId.name}$$result") - val nonLocalHandlerSym = TempSymbol(N, s"nonLocalHandler$$${labelId.name}") - val nonLocalBreakMethodMarker = TempSymbol(N, s"nonLocalBreakMethod$$${labelId.name}") - val nonLocalContinueMethodMarker = TempSymbol(N, s"nonLocalContinueMethod$$${labelId.name}") + val resultSym = new TempSymbol(N, erasedType = N, s"${labelId.name}$$result") + val nonLocalHandlerSym = TempSymbol(N, erasedType = N, s"nonLocalHandler$$${labelId.name}") + val nonLocalBreakMethodMarker = TempSymbol(N, erasedType = N, s"nonLocalBreakMethod$$${labelId.name}") + val nonLocalContinueMethodMarker = TempSymbol(N, erasedType = N, s"nonLocalContinueMethod$$${labelId.name}") val bodyTerm = ctx.withLabel( labelSym, resultSym, nonLocalHandlerSym, nonLocalBreakMethodMarker, nonLocalContinueMethodMarker).givenIn: subterm(body) @@ -1061,7 +1064,7 @@ extends Importer with ucs.SplitElaborator: case PrefixApp(kw @ Keywrd(Keyword.`drop`), body) => Term.Drop(subterm(body)).mkLocWith(kw) case Region(id: Ident, body) => - val sym = VarSymbol(id) + val sym = VarSymbol(id, erasedType = N) given Ctx = ctx + (id.name -> sym) Term.Region(sym, subterm(body)) case RegRef(reg, value) => Term.RegRef(subterm(reg), subterm(value)) @@ -1147,7 +1150,7 @@ extends Importer with ucs.SplitElaborator: raise(ErrorReport(msg"Illegal position for '_' placeholder." -> tree.toLoc :: Nil)) Term.Error case S(unds) => - val sym = VarSymbol(Ident("_" + unds.size)) + val sym = VarSymbol(Ident("_" + unds.size), erasedType = N) unds += sym sym.ref() case Annotated(lhs, rhs) => @@ -1390,7 +1393,7 @@ extends Importer with ucs.SplitElaborator: (base, term(rrhs)) val newAcc = rlhs match case id: Ident => - val sym = new VarSymbol(id) + val sym = new VarSymbol(id, erasedType = N) newCtx += id.name -> sym RcdField(Term.Lit(StrLit(id.name)).withLocOf(id), sym.ref(id)) :: DefineVar(sym, rhs_t) @@ -1489,7 +1492,7 @@ extends Importer with ucs.SplitElaborator: case N => N case _ if ctx.mode is Mode.Light => S(Term.Missing) case S(rhs) => S: - val nonLocalRetHandler = TempSymbol(N, s"nonLocalRetHandler$$${id.name}") + val nonLocalRetHandler = TempSymbol(N, erasedType = N, s"nonLocalRetHandler$$${id.name}") newCtx.nest(OuterCtx.Function(nonLocalRetHandler)).givenIn: newCtx ?=> val b = term(rhs)(using newCtx) if nonLocalRetHandler.directRefs.isEmpty then b else @@ -1510,7 +1513,7 @@ extends Importer with ucs.SplitElaborator: case _ => Modulefulness.none - val tsym = TermSymbol(k, owner, id) // TODO? + val tsym = TermSymbol(k, owner, id, erasedType = N) // TODO? val tdf = TermDefinition(k, sym, tsym, pss, tps, s, body, TermDefFlags.empty.copy(isMethod = isMethod), mfn, annotations, N).withLocOf(td) tsym.defn = S(tdf) @@ -1562,7 +1565,7 @@ extends Importer with ucs.SplitElaborator: (id, S(false)) case Modified(Keywrd(Keyword.`out`), id: Ident) => (id, S(true)) - val vs = VarSymbol(id) + val vs = VarSymbol(id, erasedType = N) val res = TyParam(FldFlags.empty, vce, vs) vs.decl = S(res) res :: Nil @@ -1624,7 +1627,7 @@ extends Importer with ucs.SplitElaborator: tsym.defn = S(fdef) fdef :: Nil else - val psym = TermSymbol(LetBind, owner, p.sym.id) + val psym = TermSymbol(LetBind, owner, p.sym.id, erasedType = N) val decl = LetDecl(psym, Nil) val defn = DefineVar(psym, p.sym.ref()) p.fldSym = S(psym) @@ -1636,7 +1639,7 @@ extends Importer with ucs.SplitElaborator: val owner = td.symbol match case s: InnerSymbol => S(s) case _: TypeAliasSymbol => die - val psym = TermSymbol(LetBind, owner, p.sym.id) + val psym = TermSymbol(LetBind, owner, p.sym.id, erasedType = N) val decl = LetDecl(psym, Nil) val defn = DefineVar(psym, p.sym.ref()) p.fldSym = S(psym) @@ -1867,15 +1870,15 @@ extends Importer with ucs.SplitElaborator: case N => N def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): TermSymbol | VarSymbol = - if ctx.outer.inner.isDefined then TermSymbol(k, ctx.outer.inner, id) - else VarSymbol(id) + if ctx.outer.inner.isDefined then TermSymbol(k, ctx.outer.inner, id, erasedType = N) + else VarSymbol(id, erasedType = N) def param(t: Tree, inUsing: Bool, inDataClass: Bool): Ctxl[Diagnostic \/ (Param, Opt[SpreadKind])] = t.desugared.asParam(inUsing).map: case pt @ ParamTree(flags, id, sign, spd, modifiers) => log(s"Elaborating ParamTree: ${pt}") val flg = flags.copy(isVal = flags.isVal || inDataClass) - val sym = VarSymbol(id) + val sym = VarSymbol(id, erasedType = N) val sig = sign.map(term(_)) val p = Param(flg, sym, sig, Modulefulness.ofSign(sig)(Mod in modifiers)) sym.decl = S(p) @@ -2001,7 +2004,7 @@ extends Importer with ucs.SplitElaborator: // We create a symbol specifically for `Param` for each variable to // avoid redundantly redeclaring symbols in `Scope` during code // generation, which triggers the assertion in `Scope.addToBindings`. - val parameterSymbol = VarSymbol(new Ident(symbol.name)) + val parameterSymbol = VarSymbol(new Ident(symbol.name), erasedType = N) (name -> parameterSymbol, symbol -> parameterSymbol) .toList.unzip pattern.variables.report // Report all invalid variables we found in `pattern`. @@ -2138,7 +2141,7 @@ extends Importer with ucs.SplitElaborator: case TyTup(ps) => val vs = ps.flatMap: case id: Ident => - val sym = VarSymbol(id) + val sym = VarSymbol(id, erasedType = N) sym.decl = S(TyParam(FldFlags.empty, N, sym)) Param(FldFlags.empty, sym, N, Modulefulness.none) :: Nil case t => diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala index f48590b82b..2acd387245 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala @@ -30,7 +30,7 @@ class Importer: val nme = file.baseName val id = alias.getOrElse(new syntax.Tree.Ident(nme)) // TODO loc - lazy val sym = TermSymbol(LetBind, N, id) + lazy val sym = TermSymbol(LetBind, N, id, erasedType = N) if path.startsWith(".") || path.startsWith("/") then // leave alone imports like "fs" log(s"importing $file") diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala index 1384866505..8b1c39d97b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala @@ -53,7 +53,7 @@ object Pattern: // TODO: The above edge case would fail the following assertion. assert(symbols.size <= 1) // If no symbol had been created before, create a new symbol now. - val symbol = symbols.headOption.getOrElse(VarSymbol(Ident(name))) + val symbol = symbols.headOption.getOrElse(VarSymbol(Ident(name), erasedType = N)) aliases.foreach: alias => // For guarded patterns (`p where t`), the variables in `p` have to be // allocated before `t` is elaborated. In that case, we don't need to diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index a630b81e7e..e07dfb039d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -10,6 +10,7 @@ import hkmc2.utils.* import Elaborator.State import Tree.Ident +import hkmc2.codegen.{ErasedType, HasErasedType, erasedType} import hkmc2.utils.SymbolSubst @@ -185,7 +186,7 @@ object FlowSymbol: end FlowSymbol -sealed trait LocalVarSymbol extends LocalSymbol +sealed trait LocalVarSymbol extends LocalSymbol with HasErasedType sealed trait LocalSymbol extends Symbol: def subst(using s: SymbolSubst): LocalSymbol sealed trait NamedSymbol extends Symbol: @@ -213,7 +214,9 @@ abstract class BlockLocalSymbol(name: Str)(using State) extends FlowSymbol(name) self: LocalSymbol => // * using `with LocalSymbol` in the `extends` clause makes Scala think there's a bad override var decl: Opt[Declaration] = N -class TempSymbol(val trm: Opt[Term], dbgNme: Str = "tmp")(using State) extends BlockLocalSymbol(dbgNme) with LocalVarSymbol: +class TempSymbol + (val trm: Opt[Term], override val erasedType: Opt[ErasedType], dbgNme: Str = "tmp")(using State) + extends BlockLocalSymbol(dbgNme) with LocalVarSymbol: // val nameHints: MutSet[Str] = MutSet.empty // * May be useful later? override def toLoc: Option[Loc] = trm.flatMap(_.toLoc) override def prefix: Str = "tmp:" @@ -229,15 +232,15 @@ class InstSymbol(val origin: Symbol)(using State) extends LocalSymbol: def subst(using sub: SymbolSubst): InstSymbol = sub.mapInstSym(this) -class VarSymbol(val id: Ident)(using State) extends BlockLocalSymbol(id.name) with NamedSymbol with LocalVarSymbol: +class VarSymbol(val id: Ident, override val erasedType: Opt[ErasedType])(using State) extends BlockLocalSymbol(id.name) with NamedSymbol with LocalVarSymbol: val name: Str = id.name override def toLoc: Opt[Loc] = id.toLoc // override def toString: Str = s"$name@$uid" override def subst(using s: SymbolSubst): VarSymbol = s.mapVarSym(this) class BuiltinSymbol - (val nme: Str, val binary: Bool, val unary: Bool, val nullary: Bool, val functionLike: Bool)(using State) - extends Symbol: + (val nme: Str, val binary: Bool, val unary: Bool, val nullary: Bool, val functionLike: Bool, override val erasedType: Opt[ErasedType])(using State) + extends Symbol with HasErasedType: def toLoc: Option[Loc] = N override def prefix: Str = "builtin:" @@ -314,7 +317,7 @@ sealed abstract class MemberSymbol(using State) extends Symbol: def subst(using SymbolSubst): MemberSymbol -class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident)(using State) +class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident, override val erasedType: Opt[ErasedType])(using State) extends MemberSymbol with DefinitionSymbol[TermDefinition] with LocalVarSymbol @@ -335,13 +338,13 @@ class ClassCtorSymbol( override val k: syntax.Fun.type, override val owner: S[ClassSymbol], id: Tree.Ident -)(using State) extends TermSymbol(k, owner, id): +)(using State) extends TermSymbol(k, owner, id, N): override def subst(using sub: SymbolSubst): ClassCtorSymbol = sub.mapClassCtorSym(this) object TermSymbol: def fromFunBms(b: BlockMemberSymbol, owner: Opt[InnerSymbol])(using State) = - TermSymbol(syntax.Fun, owner, Tree.Ident(b.nme)) + TermSymbol(syntax.Fun, owner, Tree.Ident(b.nme), N) sealed trait CtorSymbol extends Symbol: @@ -353,7 +356,8 @@ case class Extr(isTop: Bool)(using State) extends CtorSymbol: def toLoc: Option[Loc] = N override def toString: Str = nme -sealed abstract case class LitSymbol(lit: Literal)(using State) extends CtorSymbol: +sealed abstract case class LitSymbol(lit: Literal)(using State) extends CtorSymbol, HasErasedType: + override val erasedType: Opt[ErasedType] = S(lit.erasedType) def nme: Str = lit.idStr def toLoc: Option[Loc] = lit.toLoc override def prefix: Str = "lit:" @@ -378,7 +382,7 @@ case class ErrorSymbol(val nme: Str, tree: Tree)(using State) extends MemberSymb override def subst(using sub: SymbolSubst): ErrorSymbol = sub.mapErrorSym(this) override def prefix: Str = "error:" -sealed trait ClassLikeSymbol extends IdentifiedSymbol: +sealed trait ClassLikeSymbol extends IdentifiedSymbol, HasErasedType: self: MemberSymbol & DefinitionSymbol[? <: ClassDef | ModuleOrObjectDef] => val tree: Tree.TypeDef def subst(using sub: SymbolSubst): ClassLikeSymbol @@ -433,7 +437,8 @@ sealed trait InnerSymbol(using State) extends Symbol: // ensure that any implementation of InnerSymbol is also a DefinitionSymbol. self: DefinitionSymbol[? <: ClassLikeDef] => val privatesScope: Scope = Scope.empty(Scope.Cfg.default) // * Scope for private members of this symbol - val thisProxy: TempSymbol = TempSymbol(N, s"this$$$nme") + // TODO(Derppening): Can we meaningfully infer the erased type of `this` from the definition? + val thisProxy: TempSymbol = TempSymbol(N, erasedType = N, s"this$$$nme") def subst(using SymbolSubst): InnerSymbol def asDefnSym: DefinitionSymbol[? <: ClassLikeDef] & InnerSymbol = this match case d: DefinitionSymbol[? <: ClassLikeDef] => d @@ -449,6 +454,8 @@ class ClassSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using State) with InnerSymbol with NamedSymbol: + override val erasedType: Opt[ErasedType] = S(ErasedType.AnyRef(rsc = false, this)) + def name: Str = nme def nme = id.name def toLoc: Option[Loc] = id.toLoc // TODO track source tree of classe here @@ -465,6 +472,9 @@ class ModuleOrObjectSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using Sta with DefinitionSymbol[ModuleOrObjectDef] with InnerSymbol with NamedSymbol: + + override val erasedType: Opt[ErasedType] = S(ErasedType.AnyRef(rsc = false, this)) + def name: Str = nme def nme = id.name def toLoc: Option[Loc] = id.toLoc // TODO track source tree of module here @@ -476,7 +486,11 @@ class ModuleOrObjectSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using Sta class TypeAliasSymbol(val id: Tree.Ident)(using State) extends MemberSymbol - with DefinitionSymbol[TypeDef]: + with DefinitionSymbol[TypeDef] + with HasErasedType: + + override val erasedType: Opt[ErasedType] = irClsLikeDefn.flatMap(_.sym.asClsOrMod).map(ErasedType.AnyRef(rsc = false, _)) + def nme = id.name def toLoc: Option[Loc] = id.toLoc // TODO track source tree of type alias here override def prefix: Str = "type:" diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 98424e1d2d..d191aff03f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -367,7 +367,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C for (_, s) <- entries do LoweringCtx.loweringCtx.collectScopedSym(s) val objectSym = ctx.builtins.Object mkMatch( // checking that we have an object - Case.Cls(objectSym, BuiltinSymbol(objectSym.nme, false, false, true, false).asSimpleRef), + Case.Cls(objectSym, BuiltinSymbol(objectSym.nme, false, false, true, false, erasedType = S(ErasedType.objectRef)).asSimpleRef), entries.foldRight(lowerSplit(tail, cont)): case ((fieldName, fieldSymbol), blk) => mkMatch( @@ -399,7 +399,8 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C // the Label body into the rest. Wrap with an exit label and temp variable so every path stores its // result, breaks to exitLabel, then the original cont runs once. val exitLabel = new LabelSymbol(N, sym.nme + "$x") - val tmp = new TempSymbol(N) + // TODO(Derppening): Change this to `r.erasedType` when `Result.erasedType` is implemented + val tmp = new TempSymbol(N, erasedType = N) LoweringCtx.loweringCtx.collectScopedSym(tmp) val exitCont: Result => Block = r => Assign(tmp, r, Break(exitLabel)) val bodyBlock = lowerSplit(sym.body, exitCont) @@ -446,7 +447,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C // 3. The term is a `while` and the result is used. lazy val l = usesResTmp = true - val res = new TempSymbol(t) + val res = new TempSymbol(t, erasedType = N) outerCtx.collectScopedSym(res) res // The symbol for the loop label if the term is a `while`. @@ -496,8 +497,8 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C // NOTE: `shouldRewriteWhile` is not the same as `config.rewriteWhileLoops` // as shouldRewriteWhile is always true when effect handler lowering is on if config.shouldRewriteWhile then - val loopResult = TempSymbol(N) - val isReturned = TempSymbol(N) + val loopResult = TempSymbol(N, erasedType = N) + val isReturned = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool))) outerCtx.collectScopedSym(loopResult) outerCtx.collectScopedSym(isReturned) val loopEnd: Path = diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/SplitElaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/SplitElaborator.scala index 4f68072122..956bfc5364 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/SplitElaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/SplitElaborator.scala @@ -149,12 +149,12 @@ trait SplitElaborator: private def branch(using Ctx): Cfg[PartialFunction[Tree, (Ctx, SimpleSplit)]] = // Interleaved-`let` bindings like `{ x is A then 0; let x = 1; ... }`. case LetLike(Keywrd(`let`), ident: Ident, S(rhsTree), N) => - val symbol = VarSymbol(ident) + val symbol = VarSymbol(ident, erasedType = N) val head = Head.Let(symbol, term(rhsTree)) ((ctx + (ident.name -> symbol)), head ~: End) // Interleaved-`do` statements like `{ x is A then 0; do log(1); ... }`. case PrefixApp(Keywrd(`do`), rhsTree) => - (ctx, Head.Let(TempSymbol(N, "unused"), term(rhsTree)) ~: End) + (ctx, Head.Let(TempSymbol(N, erasedType = N, "unused"), term(rhsTree)) ~: End) // Although the `else`-clause marks the end of the split, we cannot // stop and still have to elaborate the remaining trees. case PrefixApp(kwTree @ Keywrd(`else`), elseTree) => @@ -254,7 +254,7 @@ trait SplitElaborator: case Term.Ref(symbol) => continuation(() => symbol.ref()) // Otherwise, we need to create a temporary symbol holding the term. case term: Term => - val symbol = TempSymbol(N, "scrut") + val symbol = TempSymbol(N, erasedType = N, "scrut") Head.Let(symbol, term) ~: continuation(() => symbol.ref()) private type TT = (Tree, Tree) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/TermSynthesizer.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/TermSynthesizer.scala index 474595f1f4..9aef64ac4b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/TermSynthesizer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/TermSynthesizer.scala @@ -93,11 +93,11 @@ trait TermSynthesizer(using State): app(stringLeave, tup(fld(t), fld(int(n))), label) protected final def tempLet(dbgName: Str, term: Term)(inner: TempSymbol => Split): Split = - val s = TempSymbol(N, dbgName) + val s = TempSymbol(N, erasedType = N, dbgName) Split.Let(s, term, inner(s)) protected final def plainTest(cond: Term, dbgName: Str = "cond")(inner: => Split): Split = - val s = TempSymbol(N, dbgName) + val s = TempSymbol(N, erasedType = N, dbgName) Split.Let(s, cond, Branch(s.safeRef, inner) ~: Split.End) protected final def makeBindings(fields: Ls[RcdField | RcdSpread]): Term = diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala index b8dfea4c78..42eb0c3722 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala @@ -131,7 +131,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter // by the set of labels (orders are not important). val labels = patterns.map(_.label) multiMatchers.get(labels).getOrElse: - val f = TempSymbol(N, makeMultiMatcherName(patterns)) + val f = TempSymbol(N, erasedType = N, makeMultiMatcherName(patterns)) multiMatchers += (labels -> f) buildQueue enqueue (f -> patterns) f // Return the symbol of the built function. @@ -147,7 +147,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val expandedPatterns = patterns.map(p => (p.label, p.expand(Set.empty))) val heads = expandedPatterns.flatMap((_, p) => p.heads).toList // This is the parameter of the current multi-matcher. - val scrutinee = VarSymbol(Ident("input")) + val scrutinee = VarSymbol(Ident("input"), erasedType = N) // Assemble branches for constructors and literals. val branches = heads.map: head => // Weird. Removing type annotations caused type errors. @@ -179,7 +179,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val subPatternsByField = Map.from(fields.map: field => field -> patterns.flatMap((_, p) => p.collectSubPatterns(field))) val subScrutinees = Map.from(subPatternsByField.map: (field, subPatterns) => - field -> MatcherResult(VarSymbol(field.asIdent), subPatterns.map(_.label))) + field -> MatcherResult(VarSymbol(field.asIdent, erasedType = N), subPatterns.map(_.label))) // Let bindings that bind the sub-scrutinee to the result of each matcher. val bindings = subPatternsByField.iterator.flatMap: (field, subPatterns) => val subScrutinee = subScrutinees(field) @@ -189,7 +189,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val conditional = // Check the presence of the field, and call the matcher if it exists. val fieldIdent: Ident = field.asIdent - val fieldSymbol = TempSymbol(N, fieldIdent.name) + val fieldSymbol = TempSymbol(N, erasedType = N, fieldIdent.name) val fieldTest = FlatPattern.Record((fieldIdent -> fieldSymbol) :: Nil) val consequent = Split.Else: val resultTerm = app(subMatcherSymbol.safeRef, tup(fld(fieldSymbol.safeRef)), "result") @@ -206,7 +206,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val z = (Nil: Ls[Statement], Nil: Ls[(Label, Term)]) val (tests, resultTerms) = patterns.iterator.foldLeft(z): case ((stmts, results), (label, pattern)) => - val symbol = TempSymbol(N, label.asFieldName + "$") + val symbol = TempSymbol(N, erasedType = N, label.asFieldName + "$") val makeSplit = completePattern(pattern, scrutinee, subScrutinees, Nil) val split = makeSplit( // There is no topmost transform here, so we emit the direct success @@ -261,7 +261,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case N => Split.Else: makeMatchSuccess(output, nullifyEmptyBindings(bindings)) case S(transform) => - val resultSymbol = TempSymbol(N, "transformResult") + val resultSymbol = TempSymbol(N, erasedType = N, "transformResult") val bindingsTerm = nullifyEmptyBindings(bindings) val transformTerm = app(transform.safeRef, tup(fld(bindingsTerm)), "the transform's result") Split.Let(resultSymbol, transformTerm, Split.Else( @@ -280,7 +280,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case ((field, pattern), makeInnerSplit) => val label = pattern.label val target = subScrutinees(field).select(label) - val resultSymbol = TempSymbol(N, s"result$label$$") + val resultSymbol = TempSymbol(N, erasedType = N, s"result$label$$") (makeConsequent, alternative) => Split.Let(resultSymbol, target, Branch(resultSymbol.safeRef, @@ -321,10 +321,10 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case ((field, pattern), makeInnerSplit) => val label = pattern.label val target = subScrutinees(field).select(label) - val resultSymbol = TempSymbol(N, s"result$label$$") - val outputSymbol = TempSymbol(N, s"output$label$$") + val resultSymbol = TempSymbol(N, erasedType = N, s"result$label$$") + val outputSymbol = TempSymbol(N, erasedType = N, s"output$label$$") val fieldAliases = pattern.aliases - val fieldBindingsSymbol = TempSymbol(N, "fieldBindings") + val fieldBindingsSymbol = TempSymbol(N, erasedType = N, "fieldBindings") val fieldBindingsTerm = makeBindings(fieldAliases.map: alias => RcdField(str(alias.name), outputSymbol.safeRef)) val fieldOutput = @@ -332,7 +332,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter else subScrutinees(field).input (outputFields: Ls[(Ident, Term)], bindingsSymbols: Ls[TempSymbol]) => ((makeConsequent, alternative) => - val bindingsSymbol = TempSymbol(N, "bindings") + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") val accumulatedBindings = val withFieldAliases = if fieldAliases.isEmpty then bindingsSymbols @@ -356,7 +356,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case ((field, pattern), makeInnerSplit) => val label = pattern.label val target = subScrutinees(field).select(label) - val resultSymbol = TempSymbol(N, s"result$label$$") + val resultSymbol = TempSymbol(N, erasedType = N, s"result$label$$") (makeConsequent, alternative) => Split.Let(resultSymbol, target, Branch(resultSymbol.safeRef, @@ -395,18 +395,18 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val label = pattern.label val target = subScrutinees(field).select(label) // This is the symbol for `MatchSuccess`. - val resultSymbol = TempSymbol(N, s"result$label$$") + val resultSymbol = TempSymbol(N, erasedType = N, s"result$label$$") // This is the symbol for the output of the pattern. - val outputSymbol = TempSymbol(N, s"output$label$$") + val outputSymbol = TempSymbol(N, erasedType = N, s"output$label$$") val outputField = RcdField(str(field.name), outputSymbol.safeRef) // This is the bindings of the current field. val fieldAliases = pattern.aliases - val fieldBindingsSymbol = TempSymbol(N, "fieldBindings") + val fieldBindingsSymbol = TempSymbol(N, erasedType = N, "fieldBindings") val fieldBindingsTerm = makeBindings(fieldAliases.map: alias => RcdField(str(alias.name), outputSymbol.safeRef)) (outputFields: Ls[RcdField], bindingsSymbols: Ls[TempSymbol]) => ((makeConsequent, alternative) => - val bindingsSymbol = TempSymbol(N, "bindings") + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") val accumulatedBindings = val withFieldAliases = if fieldAliases.isEmpty then bindingsSymbols @@ -463,9 +463,9 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter (makeConsequent, alternative) => // Here, we need to make a tuple of all output values of patterns. // Then we merge the bindings of all patterns. - val outputSymbol = TempSymbol(N, "combinedOutput") + val outputSymbol = TempSymbol(N, erasedType = N, "combinedOutput") val outputTerm = tup(allOutputs.reverseIterator.map(_.use |> fld).toSeq*) - val bindingsSymbol = TempSymbol(N, "combinedBindings") + val bindingsSymbol = TempSymbol(N, erasedType = N, "combinedBindings") // I think the bindings do not need to be reversed. val bindingsTerm = makeBindings(allBindings.map: binding => RcdSpread(binding.use)) @@ -502,9 +502,9 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter else // The symbol representing the transform function, which should be // declared at the outermost level. - val transformSymbol = TempSymbol(N, "transform") + val transformSymbol = TempSymbol(N, erasedType = N, "transform") // The transform function takes a single record as the argument. - val bindingsSymbol = VarSymbol(Ident("args")) + val bindingsSymbol = VarSymbol(Ident("args"), erasedType = N) val params = paramList(param(bindingsSymbol)) // Because we pass the extracted values using recoreds. We need to bind // each property to its corresponding variable which is accessible from @@ -525,10 +525,10 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val transformTerm = app(transformSymbol.safeRef, tup(fld(bindings.use)), "the transform's result") // Bind the transformation result to a new output symbol. - val resultSymbol = TempSymbol(N, "transformResult") + val resultSymbol = TempSymbol(N, erasedType = N, "transformResult") // Don't forget that current pattern may also have aliases which are // available in some outer transform patterns. - val currentBindingsSymbol = TempSymbol(N, "bindings") + val currentBindingsSymbol = TempSymbol(N, erasedType = N, "bindings") val currentBindings = makeBindings(aliases.map: alias => RcdField(str(alias.name), resultSymbol.safeRef)) Split.Let(resultSymbol, transformTerm, diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala index a9a2342871..593bb7cf8e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala @@ -7,6 +7,7 @@ import Message.MessageContext import ucs.{TermSynthesizer, FlatPattern, error, warn, safeRef}, ucs.extractors.* import syntax.{Fun, Keyword, Tree}, Tree.{Ident, StrLit}, Keyword.{`as`, `=>`} import collection.mutable.{Buffer, HashMap}, collection.immutable.SeqMap +import codegen.{ErasedType, PrimitiveType} import Elaborator.{Ctx, State, ctx}, utils.TL import semantics.Pattern as SP // "SP" is short for "semantic patterns" import Term.Ref @@ -28,14 +29,14 @@ object SplitCompiler: def getSubScrutinee(cs: ClassSymbol | PatternSymbol)(i: Int)(using State): SymbolScrut[?] = val scrutinees = subScrutinees.getOrElseUpdate(cs, Buffer.empty) while scrutinees.size <= i do - scrutinees += TempSymbol(N, s"arg$$${cs.nme}$$${scrutinees.size}$$").toScrut + scrutinees += TempSymbol(N, erasedType = N, s"arg$$${cs.nme}$$${scrutinees.size}$$").toScrut scrutinees(i) def getTupleLeadSubScrutinee(index: Int)(using State): SymbolScrut[?] = - tupleLead.getOrElseUpdate(index, TempSymbol(N, s"element$index$$").toScrut) + tupleLead.getOrElseUpdate(index, TempSymbol(N, erasedType = N, s"element$index$$").toScrut) def getTupleLastSubScrutinee(index: Int)(using State): SymbolScrut[?] = - tupleLast.getOrElseUpdate(index, TempSymbol(N, s"lastElement$index$$").toScrut) + tupleLast.getOrElseUpdate(index, TempSymbol(N, erasedType = N, s"lastElement$index$$").toScrut) def getFieldScrutinee(fieldName: Ident)(using State): SymbolScrut[?] = - fields.getOrElseUpdate(fieldName, TempSymbol(N, s"field_${fieldName.name}$$").toScrut) + fields.getOrElseUpdate(fieldName, TempSymbol(N, erasedType = N, s"field_${fieldName.name}$$").toScrut) object Scrut: def from(ref: Term.Ref): RefScrut = RefScrut(() => ref) @@ -109,7 +110,7 @@ object SplitCompiler: private var _hasBeenUsed = false private lazy val symbol = _hasBeenUsed = true - TempSymbol(N, nameHint.getOrElse("output")) + TempSymbol(N, erasedType = N, nameHint.getOrElse("output")) def apply(): Ref = symbol.safeRef def toList: Ls[TempSymbol] = if _hasBeenUsed then symbol :: Nil else Nil def toLet(term: => Term, tail: Split): Split = @@ -312,7 +313,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz def makePatternBindings(using State): Ls[(TempSymbol, SP)] = patterns.iterator.zipWithIndex.map: - case (pattern, index) => new TempSymbol(N, s"patternArgument${index}$$") -> pattern + case (pattern, index) => new TempSymbol(N, erasedType = N, s"patternArgument${index}$$") -> pattern .toList /** @@ -580,11 +581,11 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz // Then, we can call the `unapply` method. val unapplyArguments = patternBindings.map(_._1.safeRef |> fld) :+ fld(scrutinee()) val unapplyCall = app(sel(patternTerm, "unapply").resolve, tup(unapplyArguments*), s"result of unapply") - val unapplyResult = TempSymbol(N, "unapplyResult") + val unapplyResult = TempSymbol(N, erasedType = N, "unapplyResult") val wrapUnapply = Split.Let(unapplyResult, unapplyCall, _) // Then, we need to destruct the produced `MatchSuccess`. - val outputSymbol = TempSymbol(N, "output").toScrut // TODO: We can use `LazyScrut` for this, but the match result pattern's parameters requires a symbol. - val bindingsSymbol = TempSymbol(N, "bindings") // TODO: This can be automatically removed when no transformation is used. + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut // TODO: We can use `LazyScrut` for this, but the match result pattern's parameters requires a symbol. + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") // TODO: This can be automatically removed when no transformation is used. val wrapDestruction = (consequent: Split) => val pattern = matchSuccessPattern(S(outputSymbol.symbol :: bindingsSymbol :: Nil)) Branch(unapplyResult.safeRef, pattern, consequent) ~: alternative @@ -632,8 +633,8 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val unapplyTerm = sel(parameterTerm, "unapply").resolve val unapplyCall = app(unapplyTerm, tup(fld(scrutinee())), s"result of unapply") tempLet(s"matchSuccess_${parameterSymbol.name}", unapplyCall): matchSuccessSymbol => - val outputSymbol = TempSymbol(N, "output").toScrut - val bindingsSymbol = TempSymbol(N, "bindings").toScrut + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings").toScrut val pattern = matchSuccessPattern(S(Ls(outputSymbol.symbol, bindingsSymbol.symbol))) val consequent = makeConsequent(outputSymbol, SeqMap.empty) Branch(matchSuccessSymbol.safeRef, pattern, consequent) ~: alternative @@ -730,7 +731,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val trailingSize = trailing.size val (trailSubScrutinees, makeConsequent0) = trailing.folded(makeConsequent): index => scrutinee.getTupleLastSubScrutinee(trailingSize - index) - val spreadSubScrutinee = TempSymbol(N, "middleElements") + val spreadSubScrutinee = TempSymbol(N, erasedType = N, "middleElements") val makeConsequent1: MakeConsequent = (outerOutput, outerBindings) => makeMatchSplit(spreadSubScrutinee.toScrut, spread, false)( (spreadOutput, spreadBindings) => makeConsequent0( @@ -793,7 +794,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val params = parameters.map: case (_, parameterSymbol) => Param(FldFlags.empty, parameterSymbol, N, Modulefulness.none) - val lambdaSymbol = new TempSymbol(N, "transform") + val lambdaSymbol = new TempSymbol(N, erasedType = N, "transform") // Next, we need to elaborate the pattern into a split. Note that // `makeMatchSplit` returns a function that takes a split as the // consequence. `makeMatchSplit` also takes a list of symbols so that @@ -811,7 +812,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz (_output, bindings) => val arguments = symbols.iterator.map(bindings).map(_() |> fld).toSeq val resultTerm = app(lambdaSymbol.safeRef, tup(arguments*), "the transform's result") - val resultSymbol = TempSymbol(N, "transformResult") + val resultSymbol = TempSymbol(N, erasedType = N, "transformResult") Split.Let(resultSymbol, resultTerm, makeConsequent(resultSymbol.toScrut, SeqMap.empty)), alternative)) case Annotated(pattern, annotations) => @@ -837,7 +838,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case Guarded(pattern, guard) => (makeConsequent, alternative) => makeMatchSplit(scrutinee, pattern, true)( (output, bindings) => - val guardSymbol = TempSymbol(N, "guardResult") + val guardSymbol = TempSymbol(N, erasedType = N, "guardResult") val branch = Branch(guardSymbol.ref(), makeConsequent(output, bindings)) val innermost = Split.Let(guardSymbol, guard, branch ~: Split.End) // The creation of bindings here is repeated with the creation of @@ -858,8 +859,8 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz matchParametersWithArguments(defn, arguments) warnOnDiscardedExtractionOutputs(patternSymbol, extractionMatches) if shouldReject then RejectPrefixSplit else (makeConsequent, alternative) => - val outputSymbol = TempSymbol(N, "output").toScrut - val remainingSymbol = TempSymbol(N, "remaining").toScrut + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut + val remainingSymbol = TempSymbol(N, erasedType = N, "remaining").toScrut val (extractionArguments, makeConsequentForArguments) = extractionMatches.fold((N, makeConsequent)): _.folded(makeConsequent)(scrutinee.getSubScrutinee(patternSymbol)).mapFirst(S(_)) @@ -870,8 +871,8 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val args = tup(patternBindings.map(_._1.safeRef) :+ scrutinee()) app(sel(patternTerm, method), args, s"result of $method") val split = tempLet("unapplyResult", unapplyCall): matchSuccessSymbol => - val outputPairSymbol = TempSymbol(N, "outputPair") - val bindingsSymbol = TempSymbol(N, "bindings") + val outputPairSymbol = TempSymbol(N, erasedType = N, "outputPair") + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") Branch( matchSuccessSymbol.safeRef, matchSuccessPattern(S(outputPairSymbol :: bindingsSymbol :: Nil)), @@ -896,12 +897,12 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val unapplyCall = app(unapplyTerm, tup(fld(scrutinee())), s"result of unapply") tempLet(s"matchSuccess_${parameterSymbol.name}", unapplyCall): matchSuccessSymbol => // Destruct the `MatchSuccess` produced by the `unapply` method. - val outputPairSymbol = TempSymbol(N, "outputPair").toScrut - val bindingsSymbol = TempSymbol(N, "bindings").toScrut + val outputPairSymbol = TempSymbol(N, erasedType = N, "outputPair").toScrut + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings").toScrut val pattern = matchSuccessPattern(S(Ls(outputPairSymbol.symbol, bindingsSymbol.symbol))) // Destruct the first field of the `MatchSuccess` as a pair. - val outputSymbol = TempSymbol(N, "output").toScrut // Denotes the pattern's output. - val remainingSymbol = TempSymbol(N, "remaining").toScrut // Denotes the remaining value. + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut // Denotes the pattern's output. + val remainingSymbol = TempSymbol(N, erasedType = N, "remaining").toScrut // Denotes the remaining value. // Assemble the `Split`s in order from inside to outside. val consequent1 = makeConsequent(outputSymbol, remainingSymbol, SeqMap.empty) val consequent2 = makeTupleBranch(outputPairSymbol(), @@ -933,15 +934,15 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz ) RejectPrefixSplit case _ => (makeConsequent, alternative) => - val nonEmptySymbol = TempSymbol(N, "nonEmpty") + val nonEmptySymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "nonEmpty") val nonEmptyTerm = app( this.lt.safeRef, tup(fld(int(0)), fld(sel(scrutinee(), "length"))), "string is not empty" ) - val outputSymbol = TempSymbol(N, "stringHead") + val outputSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "stringHead") val outputTerm = callStringGet(scrutinee(), 0, "head") - val remainsSymbol = TempSymbol(N, "stringTail") + val remainsSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "stringTail") val remainsTerm = callStringDrop(scrutinee(), 1, "tail") Split.Let(nonEmptySymbol, nonEmptyTerm, Branch(nonEmptySymbol.safeRef, @@ -965,7 +966,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz makeStringPrefixMatchSplit(scrutinee, left)( (leftOutput, leftRemains, leftBindings) => makeMatchSplit(scrutinee, right)( (rightOutput, rightBindings) => - val productSymbol = TempSymbol(N, "product") + val productSymbol = TempSymbol(N, erasedType = N, "product") val productTerm = tup(leftOutput() |> fld, rightOutput() |> fld) Split.Let(productSymbol, productTerm, makeConsequent( productSymbol.toScrut, leftRemains, leftBindings ++ rightBindings)), @@ -984,7 +985,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case Wildcard() => (makeConsequent, alternative) => // Because the wildcard pattern always matches, we can match the entire // string and returns an empty string as the remaining value. - val emptyStringSymbol = TempSymbol(N, "emptyString") + val emptyStringSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "emptyString") makeConsequent(scrutinee, emptyStringSymbol.toScrut, SeqMap.empty) Branch(scrutinee(), FlatPattern.ClassLike(ctx.builtins.Str.safeRef, ctx.builtins.Str, N, false)(Tree.Dummy), Split.Let(emptyStringSymbol, str(""), @@ -993,12 +994,12 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case Literal(prefix: StrLit) => (makeConsequent, alternative) => // Check if the scrutinee is the same as the literal. If so, we return // an empty string as the remaining value. - val isLeadingSymbol = TempSymbol(N, "isLeading") + val isLeadingSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "isLeading") val isLeadingTerm = callStringStartsWith( scrutinee(), Term.Lit(prefix), "the result of startsWith") - val outputSymbol = TempSymbol(N, "consumed") + val outputSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "consumed") val outputTerm = callStringTake(scrutinee(), prefix.value.length, "the consumed part of input") - val remainsSymbol = TempSymbol(N, "remains") + val remainsSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "remains") val remainsTerm = callStringDrop(scrutinee(), prefix.value.length, "the remaining input") Split.Let(isLeadingSymbol, isLeadingTerm, Branch(isLeadingSymbol.safeRef, @@ -1010,9 +1011,9 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case Literal(_) => RejectPrefixSplit case Range(lower: StrLit, upper: StrLit, rightInclusive) => (makeConsequent, alternative) => // Check if the string is not empty. Then - val stringHeadSymbol = TempSymbol(N, "stringHead") - val stringTailSymbol = TempSymbol(N, "stringTail") - val nonEmptySymbol = TempSymbol(N, "nonEmpty") + val stringHeadSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "stringHead") + val stringTailSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "stringTail") + val nonEmptySymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "nonEmpty") val nonEmptyTerm = app(this.lt.safeRef, tup(fld(int(0)), fld(sel(scrutinee(), "length"))), "string is not empty") Split.Let(nonEmptySymbol, nonEmptyTerm, // `0 < string.length` Branch(nonEmptySymbol.safeRef, @@ -1071,7 +1072,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val params = parameters.map: case (_, parameterSymbol) => Param(FldFlags.empty, parameterSymbol, N, Modulefulness.none) - val lambdaSymbol = new TempSymbol(N, "transform") + val lambdaSymbol = new TempSymbol(N, erasedType = N, "transform") (makeConsequent, alternative) => Split.Let( sym = lambdaSymbol, term = Term.Lam(PlainParamList(params), transform), @@ -1081,7 +1082,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz (_output, remains, bindings) => val arguments = symbols.iterator.map(bindings).map(_() |> fld).toSeq val resultTerm = app(lambdaSymbol.safeRef, tup(arguments*), "the transform's result") - val resultSymbol = TempSymbol(N, "transformResult") + val resultSymbol = TempSymbol(N, erasedType = N, "transformResult") Split.Let(resultSymbol, resultTerm, makeConsequent(resultSymbol.toScrut, remains, SeqMap.empty)), alternative)) case Guarded(pattern, guard) => @@ -1089,7 +1090,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz make.whenAccept: (makeConsequent, alternative) => make( (output, remains, bindings) => - val guardSymbol = TempSymbol(N, "guardResult") + val guardSymbol = TempSymbol(N, erasedType = N, "guardResult") val branch = Branch(guardSymbol.ref(), makeConsequent(output, remains, bindings)) val innermost = Split.Let(guardSymbol, guard, branch ~: Split.End) // The creation of bindings here is repeated with the creation of @@ -1154,18 +1155,18 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val (matcherSymbol, implementations) = compiler.buildMatcher(synonym, resultMode) val innermostSplit = resultMode match case ResultMode.MatchOnly => - val resultSymbol = TempSymbol(N, "matchSuccess") + val resultSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "matchSuccess") val resultTerm = app(matcherSymbol.safeRef, tup(fld(scrutinee())), "result of matcher function") Split.Let(resultSymbol, resultTerm, Branch(resultSymbol.safeRef, makeConsequent(scrutinee, SeqMap.empty)) ~: alternative) case ResultMode.Full => // 1. Bind the call result to a variable. - val recordSymbol = TempSymbol(N, "matchRecord") + val recordSymbol = TempSymbol(N, erasedType = N, "matchRecord") val recordTerm = app(matcherSymbol.safeRef, tup(fld(scrutinee())), "result of matcher function") val f1 = Split.Let(recordSymbol, recordTerm, _) // 2. Check if the direct result is a `MatchSuccess` and bind the output. - val outputSymbol = TempSymbol(N, "patternOutput") - val bindingsSymbol = TempSymbol(N, "bindings") // TODO: This is useless. + val outputSymbol = TempSymbol(N, erasedType = N, "patternOutput") + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") // TODO: This is useless. val consequent = makeConsequent(outputSymbol.toScrut, SeqMap.empty) val pattern = matchSuccessPattern(S(outputSymbol :: bindingsSymbol :: Nil)) val branch = Branch(recordSymbol.safeRef, pattern, consequent) @@ -1224,7 +1225,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz s"compilePattern >>> $blk" ): val unapply = scoped("ucs:translation"): - val inputSymbol = VarSymbol(Ident("input")) + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) val topmost = makeMatchSplit(inputSymbol.toScrut, pd.pattern, true)( makeConsequent = (output, bindings) => def getBinding(p: Param) = bindings.get(p.sym).fold(Term.Error)(_()) @@ -1251,7 +1252,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz // We don't report errors here because they have been already reported in // the translation of `unapply` function. given Raise = Function.const(()) - val inputSymbol = VarSymbol(Ident("input")) + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) val topmost = makeStringPrefixMatchSplit(inputSymbol.toScrut, pd.pattern) ((consumedOutput, remainingOutput, bindings) => Split.Else: makeMatchSuccess(tup(fld(consumedOutput()), fld(remainingOutput()))), failure) @@ -1267,7 +1268,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz scrut: VarSymbol, topmost: Split ): Ls[Statement] = - val fieldSymbol = TempSymbol(N, name) + val fieldSymbol = TempSymbol(N, erasedType = N, name) val decl = LetDecl(fieldSymbol, Nil) val param = Param(FldFlags.empty, scrut, N, Modulefulness.none) val paramList = PlainParamList(param :: Nil) @@ -1288,7 +1289,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case _ => N term.getOrElse: val unapply = scoped("ucs:translation"): - val inputSymbol = VarSymbol(Ident("input")) + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) val topmost = makeMatchSplit(inputSymbol.toScrut, pattern, true) ((output, bindings) => Split.Else(makeMatchSuccess(output())), failure) log(s"Translated `unapply`: ${topmost.prettyPrint}") @@ -1297,7 +1298,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz // We don't report errors here because they have been already reported in // the translation of `unapply` function. given Raise = Function.const(()) - val inputSymbol = VarSymbol(Ident("input")) + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) val topmost = makeStringPrefixMatchSplit(inputSymbol.toScrut, pattern) ((consumedOutput, remainingOutput, bindings) => Split.Else: makeMatchSuccess(tup(fld(consumedOutput()), fld(remainingOutput()))), failure) diff --git a/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala b/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala index ba2a5d116b..2a19036c23 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala @@ -641,7 +641,7 @@ trait TypeDefImpl(using State) extends TypeOrTermDef: pts.flatMap(_.desugared.asParam(inUsing = inUsing).toOption).collect: case pt @ ParamTree(ident = id, spd = N) => val k = if pt.flags.mut then MutVal else ImmutVal - TermSymbol(k, symbol.asClsLike, id) + TermSymbol(k, symbol.asClsLike, id, erasedType = N) .toList lazy val allSymbols = definedSymbols ++ diff --git a/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls b/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls index 27ef775895..9f749b55ab 100644 --- a/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls +++ b/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls @@ -269,7 +269,7 @@ staged module A with :todo staged module Spread with fun f() = if [1, ..[1, 2]] is [1, ...x] then x else 0 -//│ ═══[COMPILATION ERROR] Spread parameters are not supported in staged module: Arg(Some(Lazy),SimpleRef(tmp:tmp,None)) +//│ ═══[COMPILATION ERROR] Spread parameters are not supported in staged module: Arg(Some(Lazy),SimpleRef(tmp:tmp)) //│ ═══[COMPILATION ERROR] No definition found in scope for member 'tmp' //│ > fun ctor_() = () //│ ═══[RUNTIME ERROR] Error: MLscript call unexpectedly returned `undefined`, the forbidden value. diff --git a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls index 360940490b..feaaf5a5e4 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls @@ -29,13 +29,11 @@ x + 1 //│ lhs = x⁰ //│ rhs = Lit of IntLit of 1 //│ rest = Return of Call: -//│ fun = SimpleRef: -//│ sym = builtin:+⁰ +//│ fun = SimpleRef of builtin:+⁰ //│ argss = Ls of //│ Ls of //│ Arg: -//│ value = SimpleRef: -//│ sym = x⁰ +//│ value = SimpleRef of x⁰ //│ Arg: //│ value = Lit of IntLit of 1 //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala index fda3ab1572..3ca62ad88a 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala @@ -219,7 +219,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val nestedScp = baseScp // val nestedScp = codegen.js.Scope(S(baseScp), curCtx.outer, collection.mutable.Map.empty) // * not needed - val resSym = new TempSymbol(N, "block$res") + val resSym = new TempSymbol(N, erasedType = N, "block$res") val resNme = nestedScp.allocateName(resSym) From 5c358ee74e25ef9bbb36430a3a2b430f1e71aa2e Mon Sep 17 00:00:00 2001 From: David Mak Date: Sat, 30 May 2026 12:12:10 +0800 Subject: [PATCH 08/48] codegen: Add `ErasedType.ObjectRef` --- .../src/main/scala/hkmc2/codegen/Block.scala | 38 +++++++++++-------- .../hkmc2/semantics/ucs/Normalization.scala | 2 +- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 6b9f55bbf4..2335eef264 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -849,23 +849,25 @@ enum Case: /** A primitive type of the block IR. */ enum PrimitiveType: - case Unit, Int, Int31, Num, Str, Bool - - /** The symbol for this primitive type, if available. */ - def sym(using Ctx, State): Opt[ClassLikeSymbol] = this match - case Unit => S(summon[State].unitSymbol) - case Int => S(ctx.builtins.Int) - case Int31 => S(ctx.builtins.Int31) - case Num => S(ctx.builtins.Num) - case Str => S(ctx.builtins.Str) - case Bool => S(ctx.builtins.Bool) + case Unit, Int, Int31, Num, Str, Bool, Array + + /** The symbol for this primitive type. */ + def sym(using Ctx, State): ClassLikeSymbol = this match + case Unit => summon[State].unitSymbol + case Int => ctx.builtins.Int + case Int31 => ctx.builtins.Int31 + case Num => ctx.builtins.Num + case Str => ctx.builtins.Str + case Bool => ctx.builtins.Bool + case Array => ctx.builtins.Array - -object ErasedType: - def objectRef(using Ctx): ErasedType = ErasedType.AnyRef(rsc = false, ctx.builtins.Object) - /** A generics-erased type of the Block IR. */ enum ErasedType: + /** A reference to the Object (top) type. */ + // Implementation Note: This is not collapsed into `AnyRef` to avoid the need to pass `Elaborator.Ctx` around just to + // recover `ctx.builtins.Object` + case ObjectRef + /** * An reference to a class-like symbol. * @@ -876,13 +878,19 @@ enum ErasedType: /** An primitive type. */ case Primitive(prim: PrimitiveType) + /** The symbol for this erased type. */ + def sym(using Ctx, State): ClassLikeSymbol = this match + case ObjectRef => ctx.builtins.Object + case AnyRef(_, csym) => csym + case Primitive(prim) => prim.sym + /** Trait representing a Block IR element that has an [[`ErasedType`]]. */ trait HasErasedType: /** The [[`ErasedType`]] of this element, or `N` if the erased type is not known. */ val erasedType: Opt[ErasedType] /** Similar to `erasedType`, but coerces to the top type if the specific erased type is not known. */ - def erasedType_!(using Ctx): ErasedType = erasedType.getOrElse(ErasedType.objectRef) + def erasedType_! : ErasedType = erasedType.getOrElse(ErasedType.ObjectRef) extension (lit: Literal) def erasedType: ErasedType = lit match diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index d191aff03f..4827ffc2ce 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -367,7 +367,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C for (_, s) <- entries do LoweringCtx.loweringCtx.collectScopedSym(s) val objectSym = ctx.builtins.Object mkMatch( // checking that we have an object - Case.Cls(objectSym, BuiltinSymbol(objectSym.nme, false, false, true, false, erasedType = S(ErasedType.objectRef)).asSimpleRef), + Case.Cls(objectSym, BuiltinSymbol(objectSym.nme, false, false, true, false, erasedType = S(ErasedType.ObjectRef)).asSimpleRef), entries.foldRight(lowerSplit(tail, cont)): case ((fieldName, fieldSymbol), blk) => mkMatch( From f3bfba13c134a9f5bdd6b2be5f4bd85e20e98cff Mon Sep 17 00:00:00 2001 From: David Mak Date: Sat, 30 May 2026 12:43:59 +0800 Subject: [PATCH 09/48] codegen: Bubble `HasErasedType` up to `Result` `Call`, `Lambda`, `Select`, and `DynSelect` is left for a future commit. --- .../src/main/scala/hkmc2/codegen/Block.scala | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 2335eef264..c84956babe 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -887,7 +887,7 @@ enum ErasedType: /** Trait representing a Block IR element that has an [[`ErasedType`]]. */ trait HasErasedType: /** The [[`ErasedType`]] of this element, or `N` if the erased type is not known. */ - val erasedType: Opt[ErasedType] + def erasedType: Opt[ErasedType] /** Similar to `erasedType`, but coerces to the top type if the specific erased type is not known. */ def erasedType_! : ErasedType = erasedType.getOrElse(ErasedType.ObjectRef) @@ -902,7 +902,7 @@ extension (lit: Literal) sealed trait TrivialResult extends Result -sealed abstract class Result extends AutoLocated: +sealed abstract class Result extends AutoLocated, HasErasedType: // // * Used for debugging locations: // sealed abstract class Result extends AutoLocated with ProductWithExtraInfo: // def extraInfo: Str = toLoc.toString @@ -1010,6 +1010,16 @@ sealed abstract class Result extends AutoLocated: case Value.Lit(lit) => 0 case DynSelect(qual, fld, arrayIdx) => qual.size + fld.size + lazy val erasedType: Opt[ErasedType] = this match + case Tuple(_, _) => S(ErasedType.Primitive(PrimitiveType.Array)) + case Record(_, _) => S(ErasedType.ObjectRef) + case Instantiate(_, cls, _) => cls.targetSymbol.flatMap(_.asClsOrMod).map(ErasedType.AnyRef(rsc = false, _)) + case Value.SimpleRef(sym) => sym.erasedType + case Value.MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => disamb.erasedType + case Value.This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => clsOrMod.erasedType + case Value.Lit(lit) => S(lit.erasedType) + case _ => N + // * TODO: refine this very loose type // type Local = LocalSymbol type Local = Symbol @@ -1056,7 +1066,7 @@ case class Select(qual: Path, name: Tree.Ident)(val symbol: Opt[DefinitionSymbol case class DynSelect(qual: Path, fld: Path, arrayIdx: Bool) extends Path -enum Value extends Path with HasErasedType with ProductWithExtraInfo: +enum Value extends Path with ProductWithExtraInfo: case SimpleRef(sym: LocalVarSymbol | BuiltinSymbol) /** * @param disamb The symbol disambiguating the definition that the reference refers to. @@ -1064,14 +1074,6 @@ enum Value extends Path with HasErasedType with ProductWithExtraInfo: case MemberRef(bms: BlockMemberSymbol, disamb: DefinitionSymbol[?]) case This(sym: InnerSymbol) case Lit(lit: Literal) - - /** The [[`ErasedType`]] of this value. */ - val erasedType: Opt[ErasedType] = this match - case SimpleRef(sym) => sym.erasedType - case MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => disamb.erasedType - case This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => clsOrMod.erasedType - case Lit(lit) => S(lit.erasedType) - case _ => N override def extraInfo(using DebugPrinter): Str = this match case MemberRef(bms, disamb) => s"disamb=${disamb.showAsPlain}" From 7112f5403758aad5a171098916f0dbe3f36031df Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 11:37:02 +0800 Subject: [PATCH 10/48] codegen: Tighten the erased type of some symbols --- .../src/main/scala/hkmc2/codegen/Block.scala | 3 +-- .../scala/hkmc2/codegen/HandlerLowering.scala | 2 +- .../main/scala/hkmc2/codegen/Lowering.scala | 26 ++++++++++--------- .../codegen/ReflectionInstrumenter.scala | 2 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 4 +-- .../hkmc2/semantics/ups/SplitCompiler.scala | 4 +-- 6 files changed, 21 insertions(+), 20 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index c84956babe..706cabdf8d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -706,8 +706,7 @@ object ValDefn: annotations: Ls[Annot], )(using State) : ValDefn = - // TODO(Derppening): We can probably use the erasedType from `rhs` once Path implements `HasErasedType` - ValDefn(tsym = TermSymbol(k, owner, Tree.Ident(sym.nme), erasedType = N), sym = sym, rhs = rhs)(configOverride, annotations) + ValDefn(tsym = TermSymbol(k, owner, Tree.Ident(sym.nme), erasedType = rhs.erasedType), sym, rhs)(configOverride, annotations) /* diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index d391891302..21c9a6a6a6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -537,7 +537,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, .flatMap: (sym, idx) => List(intLit(idx), Value.Lit(Tree.StrLit(sym.nme))) .map(_.asArg) - val debugInfoSym = freshTmp(erasedType = N, s"$debugNme$$debugInfo") + val debugInfoSym = freshTmp(erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), s"$debugNme$$debugInfo") // TODO: properly support spread argument by calculating the correct length. val rtArgLists = intLit(fun.params.length) :: fun.params.flatMap: pl => intLit(pl.params.length) :: pl.params.map(p => p.sym.asSimpleRef) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index f80acbf5a0..e9ba481491 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -413,8 +413,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): acc.reverse match case Nil => lowerRemainingCalls(fr, args, remainingArgss, isTailCall, loc)(k) case acc: NELs[Ls[Arg]] => - val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, "baseCall") val call = Call(fr, acc)(isMlsFun, true, isTailCall).withLoc(loc) + val tmp = loweringCtx.registerTempSymbol(N, erasedType = call.erasedType, "baseCall") Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall, loc)(k)) case (_ :: _, Nil) => k(Call(fr, acc.reverse.ne_!)(isMlsFun, true, isTailCall).withLoc(loc)) @@ -433,7 +433,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): remainingArgss match case Nil => k(call) case args :: remainingArgss => - val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, "callPrefix") + val tmp = loweringCtx.registerTempSymbol(N, erasedType = call.erasedType, "callPrefix") Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall, loc)(k)) @@ -457,8 +457,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case (Nil, Nil) => k(buildInstantiate(acc.reverse)) case (Nil, args :: remainingArgss) => - val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, "baseInst") - Assign(tmp, buildInstantiate(acc.reverse), + val inst = buildInstantiate(acc.reverse) + val tmp = loweringCtx.registerTempSymbol(N, erasedType = inst.erasedType, "baseInst") + Assign(tmp, inst, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, isTailCall = false, N)(k)) case (remainingParamss, Nil) => // * Eta-expand missing argument lists by creating lambdas for each remaining param list. @@ -496,8 +497,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Nil => k(buildInstantiate(as :: Nil)) case remainingArgss => - val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, "baseInst") - Assign(tmp, buildInstantiate(as :: Nil), + val inst = buildInstantiate(as :: Nil) + val tmp = loweringCtx.registerTempSymbol(N, erasedType = inst.erasedType, "baseInst") + Assign(tmp, inst, lowerRemainingCalls(tmp.asSimpleRef, remainingArgss.head, remainingArgss.tail, isTailCall = false, N)(k)) else zipArgs(ctorParamLists, args, Nil) @@ -660,7 +662,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): term_nonTail(body)(r => Assign(result, r, Break(label))) else val bodyResult = loweringCtx.registerTempSymbol(N, erasedType = N, "labelBodyResult") - val isContinue = loweringCtx.registerTempSymbol(N, erasedType = N, "labelContinueDispatch") + val isContinue = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "labelContinueDispatch") term_nonTail(body): r => Assign( bodyResult, @@ -1110,7 +1112,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): def rec(ps: Ls[LocalSymbol & NamedSymbol], ds: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = ps match case Nil => quote(body): r => val l = loweringCtx.registerTempSymbol(N, erasedType = N) - val arr = loweringCtx.registerTempSymbol(N, erasedType = N, "arr") + val arr = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "arr") Assign( arr, Tuple(mut = false, ds.reverse.map(_.asArg)), @@ -1125,7 +1127,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case App(lhs, Tup(rhs)) => quote(lhs): r1 => def rec(es: Ls[Elem], xs: Ls[Path])(k: Result => Block): Block = es match case Nil => - val arrSym = loweringCtx.registerTempSymbol(N, erasedType = N, "arr") + val arrSym = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "arr") Assign( arrSym, Tuple(mut = false, xs.reverse.map(_.asArg)), @@ -1150,7 +1152,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): loweringCtx.collectScopedSyms(sym) setupSymbol(sym){r1 => val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N, erasedType = N) - val arrSym = loweringCtx.registerTempSymbol(N, erasedType = N, "arr") + val arrSym = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "arr") blockBuilder.assign(l1, r1) .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(rhs)(r2 => Assign(l2, r2, b))) @@ -1227,7 +1229,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if fsr.isEmpty then Begin(b, k(asr.reverse)) else - val rcdSym = loweringCtx.registerTempSymbol(N, erasedType = N, "rcd") + val rcdSym = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.ObjectRef), "rcd") Begin( b, Assign( @@ -1261,7 +1263,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(configOverride = N, annotations = lam.annot) Define(lamDef, k(lamDef.asPath)) case r => - val l = loweringCtx.registerTempSymbol(N, erasedType = N) + val l = loweringCtx.registerTempSymbol(N, erasedType = r.erasedType) Assign(l, r, k(l.asSimpleRef)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala index dbb1efac51..1182708700 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala @@ -54,7 +54,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n def assign(res: Result, symName: Str = "tmp")(k: Path => Block): Block = // TODO: skip assignment if res: Path? - val sym = new TempSymbol(N, erasedType = N, symName) + val sym = new TempSymbol(N, erasedType = res.erasedType, symName) Scoped(Set(sym), Assign(sym, res, k(sym.asSimpleRef))) def tuple(elems: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 9f61aba405..c0a593f808 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -422,7 +422,7 @@ class TailRecOpt(using State, TL, Raise): val paramList = ogParamList.params val restParam = ogParamList.restParam - val tupleSym = TempSymbol(N, erasedType = N, "argList") + val tupleSym = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "argList") // We can safely remove all of the symbols from this parameter list from `assignedSyms` at this stage, // because the RHS of every parameter will be computed when spreading them in the tuple, which happens @@ -438,7 +438,7 @@ class TailRecOpt(using State, TL, Raise): // If the rest param exists, append a slice val (initialBlk: (Block => Block), pathList: List[Path]) = if restParam.isDefined then - val sliceResSym = TempSymbol(N, erasedType = N, "sliceRes") + val sliceResSym = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "sliceRes") // runtime.Tuple.slice(tupleSym, paramList.length, 0) val sliceRes = Call( State.runtimeSymbol.asSimpleRef diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala index 593bb7cf8e..cfb42f4c17 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala @@ -731,7 +731,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val trailingSize = trailing.size val (trailSubScrutinees, makeConsequent0) = trailing.folded(makeConsequent): index => scrutinee.getTupleLastSubScrutinee(trailingSize - index) - val spreadSubScrutinee = TempSymbol(N, erasedType = N, "middleElements") + val spreadSubScrutinee = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "middleElements") val makeConsequent1: MakeConsequent = (outerOutput, outerBindings) => makeMatchSplit(spreadSubScrutinee.toScrut, spread, false)( (spreadOutput, spreadBindings) => makeConsequent0( @@ -966,7 +966,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz makeStringPrefixMatchSplit(scrutinee, left)( (leftOutput, leftRemains, leftBindings) => makeMatchSplit(scrutinee, right)( (rightOutput, rightBindings) => - val productSymbol = TempSymbol(N, erasedType = N, "product") + val productSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "product") val productTerm = tup(leftOutput() |> fld, rightOutput() |> fld) Split.Let(productSymbol, productTerm, makeConsequent( productSymbol.toScrut, leftRemains, leftBindings ++ rightBindings)), From 6dc04ca88bd38b30e4e4af3a834bdb4423cc81c0 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 11:46:30 +0800 Subject: [PATCH 11/48] codegen/js: Fix comment alignment --- hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index bd6b9b887e..b86ad7ec23 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -317,7 +317,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: val scrutSym = scrut.map(_.sym) b match case Match( - scrut_ @ Value.SimpleRef(scrutSym_), // The scrutinee is a ref. + scrut_ @ Value.SimpleRef(scrutSym_), // The scrutinee is a ref. (Case.Lit(Tree.IntLit(curVal_)), b) :: Nil, // There is only one case matching an int literal. S(End(_)) | N, rest // Default case exists and does nothing. ) From 1e0799caca6b71528854083f78d6db255b4861f3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 13:38:09 +0800 Subject: [PATCH 12/48] semantics: Drop `using State` from `NoSymbol` --- hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 78cc15788e..1599ca8d39 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -155,7 +155,7 @@ end Symbol // * Used, eg, as the Assign receiver of intermediate computations whose result is not used -final class NoSymbol(using State) extends MaybeSymbol: +final class NoSymbol extends MaybeSymbol: def nme: Str = "‹no symbol›" override def toString: Str = nme From 2d6d2833fdb9e5e3cc6cfd5a5a026e1ab368c60a Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 13:38:29 +0800 Subject: [PATCH 13/48] codegen: Fold `ObjectRef` as `AnyRef(_, NoSymbol)` --- .../src/main/scala/hkmc2/codegen/Block.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 5c0c506459..182fedfbe1 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -849,28 +849,28 @@ enum PrimitiveType: case Str => ctx.builtins.Str case Bool => ctx.builtins.Bool case Array => ctx.builtins.Array + +object ErasedType: + def ObjectRef: ErasedType.AnyRef = AnyRef(rsc = false, NoSymbol()) /** A generics-erased type of the Block IR. */ enum ErasedType: - /** A reference to the Object (top) type. */ - // Implementation Note: This is not collapsed into `AnyRef` to avoid the need to pass `Elaborator.Ctx` around just to - // recover `ctx.builtins.Object` - case ObjectRef - /** * An reference to a class-like symbol. * + * If `csym` is `NoSymbol`, this represents the top type (`Object`). + * * - `rsc` is true if this reference is a resource class. */ - case AnyRef(rsc: Bool, csym: ClassLikeSymbol) + case AnyRef(rsc: Bool, csym: ClassLikeSymbol | NoSymbol) /** An primitive type. */ case Primitive(prim: PrimitiveType) /** The symbol for this erased type. */ def sym(using Ctx, State): ClassLikeSymbol = this match - case ObjectRef => ctx.builtins.Object - case AnyRef(_, csym) => csym + case AnyRef(_, csym: ClassLikeSymbol) => csym + case AnyRef(_, _: NoSymbol) => ctx.builtins.Object case Primitive(prim) => prim.sym /** Trait representing a Block IR element that has an [[`ErasedType`]]. */ From 8ba7fe18d01bdce36122cde26b28b3ae7e9dcd42 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 15:41:12 +0800 Subject: [PATCH 14/48] codegen: Implement printing of erased types Refinement of types during Lowering is implemented later. --- .../main/scala/hkmc2/codegen/Printer.scala | 25 ++++-- .../src/main/scala/hkmc2/semantics/Term.scala | 2 + .../src/test/mlscript/codegen/ErasedType.mls | 76 +++++++++++++++++++ .../test/scala/hkmc2/JSBackendDiffMaker.scala | 3 + .../src/test/scala/hkmc2/MLsDiffMaker.scala | 5 ++ 5 files changed, 103 insertions(+), 8 deletions(-) create mode 100644 hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index c304a55cb5..182b0ff5f0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -30,6 +30,15 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case S(str) => str case N => summon[SymbolPrinter].printSymbol(l) + def print(et: ErasedType)(using Scope): Document = et match + case ErasedType.AnyRef(rsc, csym: ClassLikeSymbol) => doc"${if rsc then "rsc " else ""}${print(csym)}" + case ErasedType.AnyRef(rsc, _: NoSymbol) => doc"${if rsc then "rsc " else ""}Object" + case ErasedType.Primitive(prim) => doc"${prim.toString}" + + def erasedTypeAnnot(x: HasErasedType)(using Scope): Document = + if !summon[ShowCfg].showErasedTypes then doc"" + else doc": ${x.erasedType.fold(doc"?")(print)}" + def print(blk: Block)(using Scope): Document = blk match case Match(scrut, arms, dflt, rest) => def case_doc(c: Case) = c match @@ -59,7 +68,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case Assign(_: NoSymbol, rhs, rest) => doc"do ${print(rhs)}; # ${print(rest)}" case Assign(lhs: (LocalVarSymbol | TermSymbol), rhs, rest) => - doc"set ${print(lhs)} = ${print(rhs)}; # ${print(rest)}" + doc"set ${print(lhs)}${erasedTypeAnnot(lhs)} = ${print(rhs)}; # ${print(rest)}" case asf @ AssignField(lhs, nme, rhs, rest) => doc"set ${print(lhs)}.${showMemberSymbol(nme.name, asf.symbol)} = ${print(rhs)}; # ${print(rest)}" case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => @@ -97,8 +106,8 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): ctor: Block, ctorSym: Opt[TermSymbol], )(using Scope): Document = - val privFields = privateFields.map(x => doc"private val ${print(x)};").mkDocument(sep = doc" # ") - val pubFields = publicFields.map(x => doc"val ${print(x._1)};").mkDocument(sep = doc" # ") + val privFields = privateFields.map(x => doc"private val ${print(x)}${erasedTypeAnnot(x)};").mkDocument(sep = doc" # ") + val pubFields = publicFields.map(x => doc"val ${print(x._1)}${erasedTypeAnnot(x._2)};").mkDocument(sep = doc" # ") val docPrivFlds = if privateFields.isEmpty then doc"" else doc" # ${privFields}" val docPubFlds = if publicFields.isEmpty then doc"" else doc" # ${pubFields}" val docPreCtor = preCtor match @@ -127,8 +136,8 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): paramss .map: pl => val allParams = - pl.params.map(x => scope.allocateName(x.sym)) ++ - pl.restParam.map(x => "..." + scope.allocateName(x.sym)) + pl.params.map(x => doc"${scope.allocateName(x.sym)}${erasedTypeAnnot(x.sym)}") ++ + pl.restParam.map(x => doc"...${scope.allocateName(x.sym)}${erasedTypeAnnot(x.sym)}") allParams.mkDocument("(", ", ", ")") .mkDocument("") @@ -140,7 +149,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): val docStaged = if fun.isStaged then doc"staged " else doc"" doc"${docStaged}fun ${print(dSym)}${docParams} ${bracedbk(docBody)}" case ValDefn(tsym, sym, rhs) => - doc"val ${print(tsym)} = ${print(rhs)}" + doc"val ${print(tsym)}${erasedTypeAnnot(tsym)} = ${print(rhs)}" case cls @ ClsLikeDefn(own, isym, sym, ctorSym, k, paramsOpt, auxParams, parentSym, methods, privateFields, publicFields, preCtor, ctor, mod, bufferable) => scope.nest.givenIn: @@ -195,8 +204,8 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case Lambda(params, body) => scope.nest.givenIn: val allParams = - params.params.map(x => scope.allocateName(x.sym)) ++ - params.restParam.map(x => "..." + scope.allocateName(x.sym)) + params.params.map(x => doc"${scope.allocateName(x.sym)}${erasedTypeAnnot(x.sym)}") ++ + params.restParam.map(x => doc"...${scope.allocateName(x.sym)}${erasedTypeAnnot(x.sym)}") val docParams = allParams.mkDocument("(", ", ", ")") doc"$docParams => ${bracedbk(print(body))}" case Tuple(mut, elems) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala index 19f9456b68..8ab151a3a2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala @@ -575,6 +575,7 @@ extension (self: Blk) case class ShowCfg( + showErasedTypes: Bool, showExpansionMappings: Bool, showFlowSymbols: Bool, debug: Bool, @@ -586,6 +587,7 @@ end ShowCfg object ShowCfg: // * For use when displaying things for internal use (not for end users) val internal = ShowCfg( + showErasedTypes = true, showFlowSymbols = true, showExpansionMappings = false, debug = false, diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls new file mode 100644 index 0000000000..aa7e581f2a --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -0,0 +1,76 @@ +:sir + +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ import Predef; end + +// Literals: Int / Str / Num / Bool. +:siret +let i = 1 +let f = 1.5 +let s = "str" +let b = true +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let i⁰, f⁰, s⁰, b⁰; set i⁰: ? = 1; set f⁰: ? = 1.5; set s⁰: ? = "str"; set b⁰: ? = true; end + +// Tuple -> Array; record -> Object (the `NoSymbol` top case). +:siret +let t = [i, s] +let r = {x: i, y: s} +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let t⁰, r⁰, x⁰, y⁰; +//│ set t⁰: ? = [i⁰, s⁰]; +//│ set x⁰: ? = i⁰; +//│ set y⁰: ? = s⁰; +//│ set r⁰: ? = { "x": x⁰, "y": y⁰ }; +//│ end + +// Instantiation -> the class symbol. +:siret +class Foo(x: Int) +let foo = new Foo(123) +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let Foo⁰, foo⁰; +//│ define Foo⁰ as class Foo² { +//│ private val x¹: ?; +//│ constructor Foo¹(x: ?) { set Foo².this.x¹ = x; end } +//│ }; +//│ set foo⁰: ? = new Foo²(123); +//│ end + +// Call return / field selection +:siret +fun mk(a) = new Foo(a) +let made = mk(1) +let sel = made.x +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let mk⁰, made⁰, sel⁰, selRes; +//│ define mk⁰ as fun mk¹(a: ?) { +//│ return new Foo²(a) +//│ }; +//│ set made⁰: ? = mk¹(1); +//│ set selRes: ? = made⁰.x﹖; +//│ do made⁰.x$__checkNotMethod﹖; +//│ match selRes +//│ undefined => +//│ throw new globalThis⁰.Error﹖("Access to required field 'x' yielded 'undefined'") +//│ else +//│ set sel⁰: ? = selRes; +//│ end +//│ end + +// Parameter annotations +:siret +fun add(x: Int, y: Int) = x + y +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let add⁰; define add⁰ as fun add¹(x: ?, y: ?) { return +⁰(x, y) }; end + +// Optimized IR with erased types: surfaces where a type relaxes to `: ?` +// through a transform (the relaxation gap `:soir` makes visible). +:soir +:siret +fun id(a) = a +id(new Foo(1)) +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let id⁰, tmp; define id⁰ as fun id¹(a: ?) { return a }; set tmp: Foo² = new Foo²(1); return id¹(tmp) +//│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— +//│ let id⁰; define id⁰ as fun id¹(a: ?) { return a }; return new Foo²(1) diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala index ef1c403d0d..5d0e63e33c 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala @@ -108,6 +108,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: lazy val blockPrinter = given ShowCfg = ShowCfg( + showErasedTypes = showIRErasedTypes.isSet, showExpansionMappings = false, showFlowSymbols = true, debug = debug.isSet, @@ -162,6 +163,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: if showIR.isSet || showIRLines.isSet then given ShowCfg = ShowCfg( + showErasedTypes = showIRErasedTypes.isSet, showExpansionMappings = false, showFlowSymbols = true, debug = debug.isSet, @@ -200,6 +202,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: if showOptimizedIR.isSet then outputSeparator("Optimized IR") given ShowCfg = ShowCfg( + showErasedTypes = showIRErasedTypes.isSet, showExpansionMappings = false, showFlowSymbols = true, debug = debug.isSet, diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index cd3d948f97..57de926f9a 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -46,6 +46,10 @@ abstract class MLsDiffMaker extends DiffMaker: val showOptimizedTree = NullaryCommand("olot") val debugOptimizations = NullaryCommand("dopt") val noOptimizations = NullaryCommand("noOpt") + val showIRErasedTypes = NullaryCommand("siret", () => + if showIR.isUnset && showOptimizedIR.isUnset then + output("Option ':siret' only has an effect if ':sir' or ':soir' is also set") + ) val showContext = NullaryCommand("ctx") val parseOnly = NullaryCommand("parseOnly") val funcToCls = NullaryCommand("ftc") @@ -453,6 +457,7 @@ abstract class MLsDiffMaker extends DiffMaker: if showFlows.isSet then import semantics.ShowCfg given ShowCfg = ShowCfg( + showErasedTypes = showIRErasedTypes.isSet, showExpansionMappings = true, showFlowSymbols = true, debug = debug.isSet, From 0c271c99b390a2fb846e2649b6cb44832df318eb Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 16:16:09 +0800 Subject: [PATCH 15/48] codegen: Implement `HasRefinableErasedType` --- .../src/main/scala/hkmc2/codegen/Block.scala | 15 +++++++++++++++ .../src/main/scala/hkmc2/semantics/Symbol.scala | 10 +++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 182fedfbe1..162a1f3ef3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -13,6 +13,7 @@ import syntax.{Literal, Tree, SpreadKind, Keyword} import semantics.* import semantics.Term.* import sem.Elaborator.{Ctx, State, ctx} +import sourcecode.{FileName, Line} /* Important design notes. @@ -881,6 +882,20 @@ trait HasErasedType: /** Similar to `erasedType`, but coerces to the top type if the specific erased type is not known. */ def erasedType_! : ErasedType = erasedType.getOrElse(ErasedType.ObjectRef) +/** A [[`HasErasedType`]] that can have its erased type refined post-construction. */ +trait HasRefinableErasedType extends HasErasedType: + // Implementation Note: Provided for overriding classes to implement `erasedType` directly as an `override var` + def erasedType_=(newType: Opt[ErasedType]): Unit + + /** Refines the erased type if it was not previously refined, but allowing for idempotent refinements to the same + * type. + * + * A soft assert will be raised if the erased type was already refined to a different type. + */ + def refineErasedType(newType: ErasedType)(using Line, FileName, Raise): Unit = + softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") + if erasedType.isEmpty then erasedType = S(newType) + extension (lit: Literal) def erasedType: ErasedType = lit match case Tree.UnitLit(_) => ErasedType.Primitive(PrimitiveType.Unit) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 1599ca8d39..eb329eb8dc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -12,6 +12,7 @@ import Elaborator.State import Tree.Ident import hkmc2.codegen.{ErasedType, HasErasedType, erasedType} import hkmc2.utils.SymbolSubst +import hkmc2.codegen.HasRefinableErasedType sealed abstract class MaybeSymbol: @@ -248,7 +249,10 @@ class InstSymbol(val origin: Symbol)(using State) extends LocalSymbol: def subst(using sub: SymbolSubst): InstSymbol = sub.mapInstSym(this) -class VarSymbol(val id: Ident, override val erasedType: Opt[ErasedType])(using State) extends LocalVarSymbol(id.name) with NamedSymbol: +class VarSymbol(val id: Ident, override var erasedType: Opt[ErasedType])(using State) + extends LocalVarSymbol(id.name) + with HasRefinableErasedType + with NamedSymbol: val name: Str = id.name override def toLoc: Opt[Loc] = id.toLoc // override def toString: Str = s"$name@$uid" @@ -337,10 +341,10 @@ sealed abstract class MemberSymbol(using State) extends Symbol: def subst(using SymbolSubst): MemberSymbol -class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident, override val erasedType: Opt[ErasedType])(using State) +class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident, override var erasedType: Opt[ErasedType])(using State) extends MemberSymbol with DefinitionSymbol[TermDefinition] - with HasErasedType + with HasRefinableErasedType with NamedSymbol: def nme: Str = id.name def name: Str = nme From 1620a30d168a435dbf96808ff04f595e561e836c Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 16:55:55 +0800 Subject: [PATCH 16/48] codegen: Add erased type refinement to parameters --- hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala | 7 +++++++ hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala | 7 +++++++ hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls | 2 +- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 162a1f3ef3..089ccb9d6f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -853,6 +853,13 @@ enum PrimitiveType: object ErasedType: def ObjectRef: ErasedType.AnyRef = AnyRef(rsc = false, NoSymbol()) + + /** Maps a [[`ClassLikeSymbol`]] into the canonical [[`ErasedType`]]. */ + def fromClsLikeSymbol(csym: ClassLikeSymbol, rsc: Bool)(using Ctx, State): ErasedType = + PrimitiveType.values.find(_.sym === csym) match + case _ if csym === ctx.builtins.Object => ObjectRef + case S(prim) => ErasedType.Primitive(prim) + case _ => ErasedType.AnyRef(rsc, csym) /** A generics-erased type of the Block IR. */ enum ErasedType: diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 632cc39d8b..fd537f40be 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -1377,8 +1377,15 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val scopedSyms = loweringCtx.getCollectedSym Scoped(scopedSyms, body) + /** Erases a type-annotated term to an [[`ErasedType`]]. */ + private def eraseSign(sign: Term): Opt[ErasedType] = + sign.symbol.flatMap(_.asClsOrMod).map(sym => ErasedType.fromClsLikeSymbol(sym, rsc = false)) + def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) (using LoweringCtx): (List[ParamList], Block) = + paramLists.foreach: pl => + pl.params.foreach: p => + p.sign.flatMap(eraseSign).foreach(p.sym.refineErasedType) val scopedBody = inScopedBlock(returnedTerm(bodyTerm)) (paramLists, scopedBody) diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index aa7e581f2a..d4bde41533 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -62,7 +62,7 @@ let sel = made.x :siret fun add(x: Int, y: Int) = x + y //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let add⁰; define add⁰ as fun add¹(x: ?, y: ?) { return +⁰(x, y) }; end +//│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int) { return +⁰(x, y) }; end // Optimized IR with erased types: surfaces where a type relaxes to `: ?` // through a transform (the relaxation gap `:soir` makes visible). From fe4b63bb10a0249fa44d8b123fac7057905447c3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 17:15:30 +0800 Subject: [PATCH 17/48] codegen: Implement erased type refinement for class fields --- .../main/scala/hkmc2/codegen/Lowering.scala | 9 ++++++++ .../src/test/mlscript/codegen/ErasedType.mls | 23 +++++++++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index fd537f40be..29ac4cbfb5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -290,6 +290,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ) case _ => _defn reportAnnotations(defn, defn.extraAnnotations) + (defn.paramsOpt.iterator ++ defn.auxParams.iterator).flatMap(_.params).foreach(refineClassParam) val bufferableAnnots = defn.annotations.flatMap: case Annot.Trm(trm: SynthSel) => if trm.sym.contains(ctx.builtins.annotations.buffered) then @@ -1381,6 +1382,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): private def eraseSign(sign: Term): Opt[ErasedType] = sign.symbol.flatMap(_.asClsOrMod).map(sym => ErasedType.fromClsLikeSymbol(sym, rsc = false)) + private def refineClassParam(p: Param): Unit = + p.sign.flatMap(eraseSign).foreach: et => + p.sym.refineErasedType(et) + p.fldSym.foreach: + case fld: TermSymbol => fld.refineErasedType(et) + case bms: BlockMemberSymbol => bms.tsym.foreach(_.refineErasedType(et)) + case _ => + def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) (using LoweringCtx): (List[ParamList], Block) = paramLists.foreach: pl => diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index d4bde41533..01e8048493 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -31,12 +31,25 @@ let foo = new Foo(123) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let Foo⁰, foo⁰; //│ define Foo⁰ as class Foo² { -//│ private val x¹: ?; -//│ constructor Foo¹(x: ?) { set Foo².this.x¹ = x; end } +//│ private val x¹: Int; +//│ constructor Foo¹(x: Int) { set Foo².this.x¹ = x; end } //│ }; //│ set foo⁰: ? = new Foo²(123); //│ end +// Instantiation on class with `val` param -> the class symbol. +:siret +class Foo(val x: Str) +let foo = new Foo(123) +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let Foo³, foo¹; +//│ define Foo³ as class Foo⁵ { +//│ val x²: Str; +//│ constructor Foo⁴(x: Str) { define x² as val x³: Str = x; end } +//│ }; +//│ set foo¹: ? = new Foo⁵(123); +//│ end + // Call return / field selection :siret fun mk(a) = new Foo(a) @@ -45,7 +58,7 @@ let sel = made.x //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let mk⁰, made⁰, sel⁰, selRes; //│ define mk⁰ as fun mk¹(a: ?) { -//│ return new Foo²(a) +//│ return new Foo⁵(a) //│ }; //│ set made⁰: ? = mk¹(1); //│ set selRes: ? = made⁰.x﹖; @@ -71,6 +84,6 @@ fun add(x: Int, y: Int) = x + y fun id(a) = a id(new Foo(1)) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let id⁰, tmp; define id⁰ as fun id¹(a: ?) { return a }; set tmp: Foo² = new Foo²(1); return id¹(tmp) +//│ let id⁰, tmp; define id⁰ as fun id¹(a: ?) { return a }; set tmp: Foo⁵ = new Foo⁵(1); return id¹(tmp) //│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— -//│ let id⁰; define id⁰ as fun id¹(a: ?) { return a }; return new Foo²(1) +//│ let id⁰; define id⁰ as fun id¹(a: ?) { return a }; return new Foo⁵(1) From 60d531bd9f49ba7ee8366b7d795ae1bb52786481 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 17:18:28 +0800 Subject: [PATCH 18/48] difftest: Add more padding to test cases in `ErasedType.mls` --- hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 01e8048493..d567b18498 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -3,6 +3,7 @@ //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ import Predef; end + // Literals: Int / Str / Num / Bool. :siret let i = 1 @@ -12,6 +13,7 @@ let b = true //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let i⁰, f⁰, s⁰, b⁰; set i⁰: ? = 1; set f⁰: ? = 1.5; set s⁰: ? = "str"; set b⁰: ? = true; end + // Tuple -> Array; record -> Object (the `NoSymbol` top case). :siret let t = [i, s] @@ -24,6 +26,7 @@ let r = {x: i, y: s} //│ set r⁰: ? = { "x": x⁰, "y": y⁰ }; //│ end + // Instantiation -> the class symbol. :siret class Foo(x: Int) @@ -37,6 +40,7 @@ let foo = new Foo(123) //│ set foo⁰: ? = new Foo²(123); //│ end + // Instantiation on class with `val` param -> the class symbol. :siret class Foo(val x: Str) @@ -50,6 +54,7 @@ let foo = new Foo(123) //│ set foo¹: ? = new Foo⁵(123); //│ end + // Call return / field selection :siret fun mk(a) = new Foo(a) @@ -71,12 +76,14 @@ let sel = made.x //│ end //│ end + // Parameter annotations :siret fun add(x: Int, y: Int) = x + y //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int) { return +⁰(x, y) }; end + // Optimized IR with erased types: surfaces where a type relaxes to `: ?` // through a transform (the relaxation gap `:soir` makes visible). :soir From adfc743ac34159fced65993a49494646d844364c Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 18:03:38 +0800 Subject: [PATCH 19/48] [WIP] codegen/wasm: Implement `ErasedType.wasmType` --- .../scala/hkmc2/codegen/wasm/text/WatBuilder.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index efb984f6c8..0e2966da34 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -26,6 +26,17 @@ extension (instr: FoldedInstr) private def mnemonicPrefix: Opt[Str] = instr.mnemonic.split('.').optionUnless(_.size == 1).map(_.head) +extension (et: ErasedType) + /** Returns the corresponding Wasm type for this [[`ErasedType`]]. */ + private def wasmType(using Ctx): Opt[RefType] = + import Ctx.ctx + et match + case ErasedType.Primitive(PrimitiveType.Int | PrimitiveType.Int31 | PrimitiveType.Bool) => + S(RefType.i31ref) + case ErasedType.AnyRef(_, csym: ClassLikeSymbol) => + csym.asBlkMember.flatMap(ctx.getType).map(RefType(_, nullable = false)) + case _ => N + object WatBuilder: /** The maximum number of characters taken to be part of the identifier asscoiated with string constants. */ val StringConstantIdentMaxLength = 16 From dd94b469170cedc0bc17993a3d227387438a82a7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 23:26:59 +0800 Subject: [PATCH 20/48] WIP: Fixup `refineErasedType` implementation and docstring --- .../src/main/scala/hkmc2/codegen/Block.scala | 8 +-- .../test/mlscript/basics/MultiParamLists.mls | 50 +++++++++++++++++++ .../src/test/mlscript/ctx/ClassCtxParams.mls | 14 +++++- .../src/test/mlscript/ctx/ExplicitlySpec.mls | 7 ++- .../test/mlscript/invalml/InvalMLCodeGen.mls | 10 ++++ 5 files changed, 80 insertions(+), 9 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 089ccb9d6f..5c294f8985 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -894,13 +894,9 @@ trait HasRefinableErasedType extends HasErasedType: // Implementation Note: Provided for overriding classes to implement `erasedType` directly as an `override var` def erasedType_=(newType: Opt[ErasedType]): Unit - /** Refines the erased type if it was not previously refined, but allowing for idempotent refinements to the same - * type. - * - * A soft assert will be raised if the erased type was already refined to a different type. - */ + /** Refines the erased type, or raises a soft assertion if the type was already previously refined. */ def refineErasedType(newType: ErasedType)(using Line, FileName, Raise): Unit = - softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") + softAssert(erasedType.isEmpty, s"Cannot refine already-refined erased type $erasedType to $newType") if erasedType.isEmpty then erasedType = S(newType) extension (lit: Literal) diff --git a/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls b/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls index 3aa06fd0e9..ed1c58f075 100644 --- a/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls +++ b/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls @@ -7,6 +7,11 @@ fun f(n1: Int): Int = n1 //│ ————————————| JS (unsanitized) |———————————————————————————————————————————————————————————————————— //│ let f; f = function f(n1) { return n1 }; +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— f(42) @@ -22,6 +27,16 @@ f(42) fun f(n1: Int)(n2: Int): Int = (10 * n1 + n2) //│ ————————————| JS (unsanitized) |———————————————————————————————————————————————————————————————————— //│ let f1; f1 = function f(n1) { return (n2) => { let tmp; tmp = 10 * n1; return tmp + n2 } }; +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— // TODO compile this to @@ -47,6 +62,21 @@ fun f(n1: Int)(n2: Int)(n3: Int): Int = 10 * (10 * n1 + n2) + n3 //│ } //│ } //│ }; +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— f(4)(2)(0) @@ -73,6 +103,26 @@ fun f(n1: Int)(n2: Int)(n3: Int)(n4: Int): Int = 10 * (10 * (10 * n1 + n2) + n3) //│ } //│ } //│ }; +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— f(3)(0)(3)(1) diff --git a/hkmc2/shared/src/test/mlscript/ctx/ClassCtxParams.mls b/hkmc2/shared/src/test/mlscript/ctx/ClassCtxParams.mls index cb6e047e12..3d17ac4389 100644 --- a/hkmc2/shared/src/test/mlscript/ctx/ClassCtxParams.mls +++ b/hkmc2/shared/src/test/mlscript/ctx/ClassCtxParams.mls @@ -115,16 +115,26 @@ class Foo(using T) with //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "Foo", [null]]; //│ }); +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1387) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1387': Cannot refine already-refined erased type Some(AnyRef(false,class:T)) to AnyRef(false,class:T) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1389) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1389': Cannot refine already-refined erased type Some(AnyRef(false,class:T)) to AnyRef(false,class:T) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— :todo :sjs x => x.Foo#S //│ ╔══[COMPILATION ERROR] Class 'Foo' does not contain member 'S'. -//│ ║ l.122: x => x.Foo#S +//│ ║ l.132: x => x.Foo#S //│ ╙── ^ //│ ╔══[COMPILATION ERROR] Cannot query instance of type T for call: -//│ ║ l.122: x => x.Foo#S +//│ ║ l.132: x => x.Foo#S //│ ║ ^^^^ //│ ╟── Required by contextual parameter declaration: //│ ║ l.101: class Foo(using T) with diff --git a/hkmc2/shared/src/test/mlscript/ctx/ExplicitlySpec.mls b/hkmc2/shared/src/test/mlscript/ctx/ExplicitlySpec.mls index 97f2bf0ce6..de199ef4fb 100644 --- a/hkmc2/shared/src/test/mlscript/ctx/ExplicitlySpec.mls +++ b/hkmc2/shared/src/test/mlscript/ctx/ExplicitlySpec.mls @@ -72,13 +72,18 @@ f(using 32) //│ ║ l.63: fun f(using i: Int)(j: Int)(using k: Int)(l: Num): Int = i + j + k + l //│ ║ ^^^^^^ //│ ╙── Missing instance: Expected: Int; Available: ‹none available› +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. //│ = fun :e :re f(using 32)(8) //│ ╔══[COMPILATION ERROR] Cannot query instance of type Int for call: -//│ ║ l.79: f(using 32)(8) +//│ ║ l.84: f(using 32)(8) //│ ║ ^^^^^^^^^^^^^^ //│ ╟── Required by contextual parameter declaration: //│ ║ l.63: fun f(using i: Int)(j: Int)(using k: Int)(l: Num): Int = i + j + k + l diff --git a/hkmc2/shared/src/test/mlscript/invalml/InvalMLCodeGen.mls b/hkmc2/shared/src/test/mlscript/invalml/InvalMLCodeGen.mls index ad5acd3d0a..fc046c4b80 100644 --- a/hkmc2/shared/src/test/mlscript/invalml/InvalMLCodeGen.mls +++ b/hkmc2/shared/src/test/mlscript/invalml/InvalMLCodeGen.mls @@ -81,6 +81,16 @@ data class Foo(x: Int) //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "Foo", ["x"]]; //│ }); +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1387) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1387': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. +//│ FAILURE: Unexpected internal error +//│ FAILURE LOCATION: softAssert (Lowering.scala:1390) +//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1390': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) +//│ ╟── The compilation result may be incorrect. +//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— From 197acdef19f9733ad127570e651f67c10478f74a Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 1 Jun 2026 23:44:27 +0800 Subject: [PATCH 21/48] codegen: Update printer to print type annotations on `let` --- .../main/scala/hkmc2/codegen/Printer.scala | 14 ++++--- .../src/test/mlscript/codegen/ErasedType.mls | 37 ++++++++++--------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index 182b0ff5f0..512c7e9cf7 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -68,7 +68,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case Assign(_: NoSymbol, rhs, rest) => doc"do ${print(rhs)}; # ${print(rest)}" case Assign(lhs: (LocalVarSymbol | TermSymbol), rhs, rest) => - doc"set ${print(lhs)}${erasedTypeAnnot(lhs)} = ${print(rhs)}; # ${print(rest)}" + doc"set ${print(lhs)} = ${print(rhs)}; # ${print(rest)}" case asf @ AssignField(lhs, nme, rhs, rest) => doc"set ${print(lhs)}.${showMemberSymbol(nme.name, asf.symbol)} = ${print(rhs)}; # ${print(rest)}" case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => @@ -78,7 +78,9 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case Scoped(syms, body) => scope.nest.givenIn: import hkmc2.given_Ordering_Uid // Not sure why needed... - val names = syms.toList.sortBy(_.uid).map(s => scope.allocateName(s)) + val names = syms.toList.sortBy(_.uid).map: + case sym: LocalVarSymbol => doc"${scope.allocateName(sym)}${erasedTypeAnnot(sym)}" + case bms: BlockMemberSymbol => doc"${scope.allocateName(bms)}${bms.tsym.fold(doc"")(erasedTypeAnnot(_))}" doc"let ${names.mkDocument(", ")}; # ${print(body)}" case End(msg) if msg.nonEmpty && config.commentGeneratedCode => doc"end /* ${msg} */" case End(_) => doc"end" @@ -235,10 +237,12 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): // * therefore, we want to avoid printing them with fresh names but use their `dbgName`s instead. scope.nest.givenIn: import hkmc2.given_Ordering_Uid // Not sure why needed... + val symPrinter = summon[SymbolPrinter] val names = syms.toList.sortBy(_.uid).map: - case s: TempSymbol => scope.allocateName(s) - case s => summon[SymbolPrinter].printSymbol(s) - doc"let ${names.mkString(", ")}; # ${print(body)}" + case s: TempSymbol => doc"${scope.allocateName(s)}${erasedTypeAnnot(s)}" + case s: LocalVarSymbol => doc"${symPrinter.printSymbol(s)}${erasedTypeAnnot(s)}" + case s: BlockMemberSymbol => doc"${symPrinter.printSymbol(s)}${s.tsym.fold(doc"")(erasedTypeAnnot(_))}" + doc"let ${names.mkDocument(", ")}; # ${print(body)}" case m => print(m) }" diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index d567b18498..7f046d9102 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -11,7 +11,7 @@ let f = 1.5 let s = "str" let b = true //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let i⁰, f⁰, s⁰, b⁰; set i⁰: ? = 1; set f⁰: ? = 1.5; set s⁰: ? = "str"; set b⁰: ? = true; end +//│ let i⁰: ?, f⁰: ?, s⁰: ?, b⁰: ?; set i⁰ = 1; set f⁰ = 1.5; set s⁰ = "str"; set b⁰ = true; end // Tuple -> Array; record -> Object (the `NoSymbol` top case). @@ -19,11 +19,11 @@ let b = true let t = [i, s] let r = {x: i, y: s} //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let t⁰, r⁰, x⁰, y⁰; -//│ set t⁰: ? = [i⁰, s⁰]; -//│ set x⁰: ? = i⁰; -//│ set y⁰: ? = s⁰; -//│ set r⁰: ? = { "x": x⁰, "y": y⁰ }; +//│ let t⁰: ?, r⁰: ?, x⁰: ?, y⁰: ?; +//│ set t⁰ = [i⁰, s⁰]; +//│ set x⁰ = i⁰; +//│ set y⁰ = s⁰; +//│ set r⁰ = { "x": x⁰, "y": y⁰ }; //│ end @@ -32,12 +32,12 @@ let r = {x: i, y: s} class Foo(x: Int) let foo = new Foo(123) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let Foo⁰, foo⁰; +//│ let Foo⁰: ?, foo⁰: ?; //│ define Foo⁰ as class Foo² { //│ private val x¹: Int; //│ constructor Foo¹(x: Int) { set Foo².this.x¹ = x; end } //│ }; -//│ set foo⁰: ? = new Foo²(123); +//│ set foo⁰ = new Foo²(123); //│ end @@ -46,12 +46,12 @@ let foo = new Foo(123) class Foo(val x: Str) let foo = new Foo(123) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let Foo³, foo¹; +//│ let Foo³: ?, foo¹: ?; //│ define Foo³ as class Foo⁵ { //│ val x²: Str; //│ constructor Foo⁴(x: Str) { define x² as val x³: Str = x; end } //│ }; -//│ set foo¹: ? = new Foo⁵(123); +//│ set foo¹ = new Foo⁵(123); //│ end @@ -61,18 +61,18 @@ fun mk(a) = new Foo(a) let made = mk(1) let sel = made.x //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let mk⁰, made⁰, sel⁰, selRes; +//│ let mk⁰: ?, made⁰: ?, sel⁰: ?, selRes: ?; //│ define mk⁰ as fun mk¹(a: ?) { //│ return new Foo⁵(a) //│ }; -//│ set made⁰: ? = mk¹(1); -//│ set selRes: ? = made⁰.x﹖; +//│ set made⁰ = mk¹(1); +//│ set selRes = made⁰.x﹖; //│ do made⁰.x$__checkNotMethod﹖; //│ match selRes //│ undefined => //│ throw new globalThis⁰.Error﹖("Access to required field 'x' yielded 'undefined'") //│ else -//│ set sel⁰: ? = selRes; +//│ set sel⁰ = selRes; //│ end //│ end @@ -81,7 +81,7 @@ let sel = made.x :siret fun add(x: Int, y: Int) = x + y //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int) { return +⁰(x, y) }; end +//│ let add⁰: ?; define add⁰ as fun add¹(x: Int, y: Int) { return +⁰(x, y) }; end // Optimized IR with erased types: surfaces where a type relaxes to `: ?` @@ -91,6 +91,9 @@ fun add(x: Int, y: Int) = x + y fun id(a) = a id(new Foo(1)) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let id⁰, tmp; define id⁰ as fun id¹(a: ?) { return a }; set tmp: Foo⁵ = new Foo⁵(1); return id¹(tmp) +//│ let id⁰: ?, tmp: Foo⁵; +//│ define id⁰ as fun id¹(a: ?) { return a }; +//│ set tmp = new Foo⁵(1); +//│ return id¹(tmp) //│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— -//│ let id⁰; define id⁰ as fun id¹(a: ?) { return a }; return new Foo⁵(1) +//│ let id⁰: ?; define id⁰ as fun id¹(a: ?) { return a }; return new Foo⁵(1) From e3ff6efe3c4a51db2d676ed5cd811533d6ffa80d Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 2 Jun 2026 14:14:09 +0800 Subject: [PATCH 22/48] codegen: Do not show annotations for types --- .../src/main/scala/hkmc2/codegen/Printer.scala | 4 ++-- .../src/test/mlscript/codegen/ErasedType.mls | 15 ++++++--------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index 512c7e9cf7..f9acc55ac5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -80,7 +80,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): import hkmc2.given_Ordering_Uid // Not sure why needed... val names = syms.toList.sortBy(_.uid).map: case sym: LocalVarSymbol => doc"${scope.allocateName(sym)}${erasedTypeAnnot(sym)}" - case bms: BlockMemberSymbol => doc"${scope.allocateName(bms)}${bms.tsym.fold(doc"")(erasedTypeAnnot(_))}" + case bms: BlockMemberSymbol => doc"${scope.allocateName(bms)}" doc"let ${names.mkDocument(", ")}; # ${print(body)}" case End(msg) if msg.nonEmpty && config.commentGeneratedCode => doc"end /* ${msg} */" case End(_) => doc"end" @@ -241,7 +241,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): val names = syms.toList.sortBy(_.uid).map: case s: TempSymbol => doc"${scope.allocateName(s)}${erasedTypeAnnot(s)}" case s: LocalVarSymbol => doc"${symPrinter.printSymbol(s)}${erasedTypeAnnot(s)}" - case s: BlockMemberSymbol => doc"${symPrinter.printSymbol(s)}${s.tsym.fold(doc"")(erasedTypeAnnot(_))}" + case s: BlockMemberSymbol => doc"${symPrinter.printSymbol(s)}" doc"let ${names.mkDocument(", ")}; # ${print(body)}" case m => print(m) }" diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 7f046d9102..836a56ace5 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -32,7 +32,7 @@ let r = {x: i, y: s} class Foo(x: Int) let foo = new Foo(123) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let Foo⁰: ?, foo⁰: ?; +//│ let Foo⁰, foo⁰: ?; //│ define Foo⁰ as class Foo² { //│ private val x¹: Int; //│ constructor Foo¹(x: Int) { set Foo².this.x¹ = x; end } @@ -46,7 +46,7 @@ let foo = new Foo(123) class Foo(val x: Str) let foo = new Foo(123) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let Foo³: ?, foo¹: ?; +//│ let Foo³, foo¹: ?; //│ define Foo³ as class Foo⁵ { //│ val x²: Str; //│ constructor Foo⁴(x: Str) { define x² as val x³: Str = x; end } @@ -61,7 +61,7 @@ fun mk(a) = new Foo(a) let made = mk(1) let sel = made.x //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let mk⁰: ?, made⁰: ?, sel⁰: ?, selRes: ?; +//│ let mk⁰, made⁰: ?, sel⁰: ?, selRes: ?; //│ define mk⁰ as fun mk¹(a: ?) { //│ return new Foo⁵(a) //│ }; @@ -81,7 +81,7 @@ let sel = made.x :siret fun add(x: Int, y: Int) = x + y //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let add⁰: ?; define add⁰ as fun add¹(x: Int, y: Int) { return +⁰(x, y) }; end +//│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int) { return +⁰(x, y) }; end // Optimized IR with erased types: surfaces where a type relaxes to `: ?` @@ -91,9 +91,6 @@ fun add(x: Int, y: Int) = x + y fun id(a) = a id(new Foo(1)) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let id⁰: ?, tmp: Foo⁵; -//│ define id⁰ as fun id¹(a: ?) { return a }; -//│ set tmp = new Foo⁵(1); -//│ return id¹(tmp) +//│ let id⁰, tmp: Foo⁵; define id⁰ as fun id¹(a: ?) { return a }; set tmp = new Foo⁵(1); return id¹(tmp) //│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— -//│ let id⁰: ?; define id⁰ as fun id¹(a: ?) { return a }; return new Foo⁵(1) +//│ let id⁰; define id⁰ as fun id¹(a: ?) { return a }; return new Foo⁵(1) From df3b6aedde7298d76e371ae8af139d871c398b17 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 2 Jun 2026 16:46:09 +0800 Subject: [PATCH 23/48] [WIP] codegen/wasm: Restrict type of local `$this` --- .../scala/hkmc2/codegen/wasm/text/Ctx.scala | 15 +++++++++-- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 27 +++++++++++++++---- .../shared/src/test/mlscript/wasm/Basics.mls | 6 ++--- .../test/mlscript/wasm/ClassInheritance.mls | 4 +-- .../src/test/mlscript/wasm/ClassMethods.mls | 2 +- .../src/test/mlscript/wasm/ControlFlow.mls | 2 +- .../src/test/mlscript/wasm/Matching.mls | 4 +-- .../src/test/mlscript/wasm/ScopedLocals.mls | 2 +- .../src/test/mlscript/wasm/SingletonUnit.mls | 2 +- .../src/test/mlscript/wasm/Singletons.mls | 4 +-- .../src/test/mlscript/wasm/VirtualMethods.mls | 4 +-- 11 files changed, 50 insertions(+), 22 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala index f961706435..e519c5274d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala @@ -201,7 +201,7 @@ class FuncInfo( val body: Expr, val exportName: Opt[Str], val wrapId: Opt[Str] -> Opt[Str] = N -> N, -)(using Ctx, Raise) extends ToWat: +)(using Ctx, Raise, State) extends ToWat: /** Symbolic identifier for the function. */ val id = SymIdx(summon[Ctx].funcScp.allocateOrGetNameWrapped(sym, wrapId)) @@ -220,7 +220,7 @@ class FuncInfo( getSignatureType.toWat.surroundUnlessEmpty(doc" ") } #{ ${ locals.map: p => - doc"(local ${p._2.toWat} ${RefType.anyref.toWat})" + doc"(local ${p._2.toWat} ${p._1.localRefType.toWat})" .mkDocument(doc" # ").surroundUnlessEmpty(doc" # ") } # ${body.toWat} #} )""" end FuncInfo @@ -404,6 +404,17 @@ class FunctionCtx(_params: Ls[ParamList], thisSym: Opt[InnerSymbol])(using Raise */ def locals: Seq[ValueSymbol -> SymIdx] = _locals.map(l => l -> SymIdx(localScp.lookup_!(l, N))).toSeq + /** The declared Wasm reference type of the param/local slot for `sym`. + * + * Parameters are uniformly `anyref`: their declared type is fixed by the shared call/vtable + * calling convention, independent of `sym.erasedType` (e.g. a virtually-dispatched method's + * `this` must stay `anyref` to match the shared vtable signature even when its erased type names + * a concrete class). Local slots derive their type from the symbol's erased type via + * [[localRefType]]. + */ + def slotRefType(sym: ValueSymbol)(using Ctx): RefType = + if params.exists(_._1 == sym) then RefType.anyref else sym.localRefType + /** Pushes a label target for the dynamic extent of `body` and pops it afterwards. * * The `body` function is given a [[LabelTarget]] containing the `break` and `continue` labels corresponding to diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 0e2966da34..0e5281b554 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -28,7 +28,7 @@ extension (instr: FoldedInstr) extension (et: ErasedType) /** Returns the corresponding Wasm type for this [[`ErasedType`]]. */ - private def wasmType(using Ctx): Opt[RefType] = + private def wasmType(using Ctx): Opt[RefType] = import Ctx.ctx et match case ErasedType.Primitive(PrimitiveType.Int | PrimitiveType.Int31 | PrimitiveType.Bool) => @@ -37,6 +37,22 @@ extension (et: ErasedType) csym.asBlkMember.flatMap(ctx.getType).map(RefType(_, nullable = false)) case _ => N +extension (sym: ValueSymbol) + /** The Wasm reference type a *local* slot for `sym` should be declared with. + * + * Use [[FunctionCtx.slowRefType]] for parameter slots, which handles `anyref` widening due to virtual dispatch + * calling conventions. + */ + private[text] def localRefType(using Ctx, State): RefType = + import Ctx.ctx + sym match + case isym: InnerSymbol => + val structSym = isym.asBlkMember orElse: + Option.when(isym eq State.unitSymbol): + State.unitBlockMemberSymbol + structSym.flatMap(ctx.getType).map(RefType(_, nullable = false)).getOrElse(RefType.anyref) + case _ => RefType.anyref + object WatBuilder: /** The maximum number of characters taken to be part of the identifier asscoiated with string constants. */ val StringConstantIdentMaxLength = 16 @@ -1103,7 +1119,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: lastWords(s"ValueSymbol `$ts` (${ts.getClass.getSimpleName}) cannot be resolved as a variable") case l => funcCtx.lookupLocal(l) match - case S(localIdx) => local.get(localIdx, RefType.anyref) + case S(localIdx) => local.get(localIdx, funcCtx.slotRefType(l)) case N if ctx.containsGlobal(l) => global.get(ctx.getGlobal_!(l), ctx.getGlobalType_!(l).globalType.valType) case _ => @@ -1237,9 +1253,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: singletonInfoFor(sym) match case S(info) => singletonGlobalGet(info) case N => - // TODO(Derppening): Remove `ref.cast` once erased-typed IR is implemented - ref.cast( - local.get(funcCtx.lookupLocal_!(sym, sym.toLoc), RefType.anyref), + // The cast is necessary if `$this` is a parameter rather than a local, since `$this` parameters are typed as + // `anyref` to support virtual dispatch + castConserve( + local.get(funcCtx.lookupLocal_!(sym, sym.toLoc), funcCtx.slotRefType(sym)), RefType( sym.asBlkMember.fold(baseObjectTypeIdx)(ctx.getType_!(_)), nullable = false, diff --git a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls index 0086150abc..0ff819633f 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls @@ -121,7 +121,7 @@ class Foo(val a) //│ (return //│ (local.get $this)))) //│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Foo @@ -194,7 +194,7 @@ class Foo(val x) with //│ (return //│ (local.get $this)))) //│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Foo @@ -266,7 +266,7 @@ O.y //│ (return //│ (local.get $this)))) //│ (func $O_ctor (type $O_ctor) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $O)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $O diff --git a/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls b/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls index f4e213da63..d602b1a536 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls @@ -58,7 +58,7 @@ c.Parent#x + (c is Parent) //│ (return //│ (local.get $this)))) //│ (func $Parent_ctor (export "Parent") (type $Parent_ctor) (param $x (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Parent)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Parent @@ -94,7 +94,7 @@ c.Parent#x + (c is Parent) //│ (return //│ (local.get $this)))) //│ (func $Child_ctor (export "Child") (type $Child_ctor) (param $y (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Child)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Child diff --git a/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls b/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls index 29b1c24b48..ed368a2470 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls @@ -36,7 +36,7 @@ a.A#get() + A(2).get() //│ (return //│ (local.get $this)))) //│ (func $A_ctor (export "A") (type $A_ctor) (param $x (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $A)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $A diff --git a/hkmc2/shared/src/test/mlscript/wasm/ControlFlow.mls b/hkmc2/shared/src/test/mlscript/wasm/ControlFlow.mls index 1fc45c466a..e0fef568a3 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ControlFlow.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ControlFlow.mls @@ -37,7 +37,7 @@ let i = 0 in //│ (return //│ (local.get $this)))) //│ (func $Unit_Unit (type $Unit_ctor) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Unit)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Unit diff --git a/hkmc2/shared/src/test/mlscript/wasm/Matching.mls b/hkmc2/shared/src/test/mlscript/wasm/Matching.mls index bbd424dd8e..fda549d3a9 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Matching.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Matching.mls @@ -72,7 +72,7 @@ if Bar(true) is //│ (return //│ (local.get $this)))) //│ (func $Bar_ctor (export "Bar") (type $Bar_ctor) (param $y (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Bar)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Bar @@ -96,7 +96,7 @@ if Bar(true) is //│ (return //│ (local.get $this)))) //│ (func $Baz_ctor (export "Baz") (type $Baz_ctor) (param $z (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Baz)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Baz diff --git a/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls b/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls index cc19959348..4a5b9f7311 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls @@ -162,7 +162,7 @@ class Foo(val a, val b) //│ (return //│ (local.get $this)))) //│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (param $b (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Foo diff --git a/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls b/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls index f1479afa2f..7cb4465dcd 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls @@ -25,7 +25,7 @@ //│ (return //│ (local.get $this)))) //│ (func $Unit_Unit (type $Unit_ctor) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Unit)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Unit diff --git a/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls b/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls index 64a2892d2d..f1d81d550c 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls @@ -44,7 +44,7 @@ Bar.y //│ (return //│ (local.get $this)))) //│ (func $Foo_ctor (type $Foo_ctor) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Foo @@ -68,7 +68,7 @@ Bar.y //│ (return //│ (local.get $this)))) //│ (func $Bar_ctor (type $Bar_ctor) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $Bar)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $Bar diff --git a/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls b/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls index d4e2725b65..210ef70f59 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls @@ -54,7 +54,7 @@ callF(B()) //│ (return //│ (local.get $this)))) //│ (func $A_ctor (export "A") (type $A_ctor) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $A)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $A @@ -81,7 +81,7 @@ callF(B()) //│ (return //│ (local.get $this)))) //│ (func $B_ctor (export "B") (type $B_ctor) (result (ref null any)) -//│ (local $this (ref null any)) +//│ (local $this (ref $B)) //│ (block (result (ref null any)) //│ (local.set $this //│ (struct.new $B From b33b14644fdeb0540ba3b8b3bd290922c4f4fa75 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 3 Jun 2026 20:11:56 +0800 Subject: [PATCH 24/48] Revert dd94b46 --- .../src/main/scala/hkmc2/codegen/Block.scala | 3 +- .../test/mlscript/basics/MultiParamLists.mls | 50 ------------------- .../src/test/mlscript/ctx/ClassCtxParams.mls | 14 +----- .../src/test/mlscript/ctx/ExplicitlySpec.mls | 7 +-- .../test/mlscript/invalml/InvalMLCodeGen.mls | 10 ---- 5 files changed, 5 insertions(+), 79 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 5c294f8985..ef0c06d1a6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -896,7 +896,8 @@ trait HasRefinableErasedType extends HasErasedType: /** Refines the erased type, or raises a soft assertion if the type was already previously refined. */ def refineErasedType(newType: ErasedType)(using Line, FileName, Raise): Unit = - softAssert(erasedType.isEmpty, s"Cannot refine already-refined erased type $erasedType to $newType") + // TODO(Derppening): Restore `erasedType.isEmpty` once JS sanitization is converted into a pass + softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") if erasedType.isEmpty then erasedType = S(newType) extension (lit: Literal) diff --git a/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls b/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls index ed1c58f075..3aa06fd0e9 100644 --- a/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls +++ b/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls @@ -7,11 +7,6 @@ fun f(n1: Int): Int = n1 //│ ————————————| JS (unsanitized) |———————————————————————————————————————————————————————————————————— //│ let f; f = function f(n1) { return n1 }; -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— f(42) @@ -27,16 +22,6 @@ f(42) fun f(n1: Int)(n2: Int): Int = (10 * n1 + n2) //│ ————————————| JS (unsanitized) |———————————————————————————————————————————————————————————————————— //│ let f1; f1 = function f(n1) { return (n2) => { let tmp; tmp = 10 * n1; return tmp + n2 } }; -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— // TODO compile this to @@ -62,21 +47,6 @@ fun f(n1: Int)(n2: Int)(n3: Int): Int = 10 * (10 * n1 + n2) + n3 //│ } //│ } //│ }; -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— f(4)(2)(0) @@ -103,26 +73,6 @@ fun f(n1: Int)(n2: Int)(n3: Int)(n4: Int): Int = 10 * (10 * (10 * n1 + n2) + n3) //│ } //│ } //│ }; -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— f(3)(0)(3)(1) diff --git a/hkmc2/shared/src/test/mlscript/ctx/ClassCtxParams.mls b/hkmc2/shared/src/test/mlscript/ctx/ClassCtxParams.mls index 3d17ac4389..cb6e047e12 100644 --- a/hkmc2/shared/src/test/mlscript/ctx/ClassCtxParams.mls +++ b/hkmc2/shared/src/test/mlscript/ctx/ClassCtxParams.mls @@ -115,26 +115,16 @@ class Foo(using T) with //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "Foo", [null]]; //│ }); -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1387) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1387': Cannot refine already-refined erased type Some(AnyRef(false,class:T)) to AnyRef(false,class:T) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1389) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1389': Cannot refine already-refined erased type Some(AnyRef(false,class:T)) to AnyRef(false,class:T) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— :todo :sjs x => x.Foo#S //│ ╔══[COMPILATION ERROR] Class 'Foo' does not contain member 'S'. -//│ ║ l.132: x => x.Foo#S +//│ ║ l.122: x => x.Foo#S //│ ╙── ^ //│ ╔══[COMPILATION ERROR] Cannot query instance of type T for call: -//│ ║ l.132: x => x.Foo#S +//│ ║ l.122: x => x.Foo#S //│ ║ ^^^^ //│ ╟── Required by contextual parameter declaration: //│ ║ l.101: class Foo(using T) with diff --git a/hkmc2/shared/src/test/mlscript/ctx/ExplicitlySpec.mls b/hkmc2/shared/src/test/mlscript/ctx/ExplicitlySpec.mls index de199ef4fb..97f2bf0ce6 100644 --- a/hkmc2/shared/src/test/mlscript/ctx/ExplicitlySpec.mls +++ b/hkmc2/shared/src/test/mlscript/ctx/ExplicitlySpec.mls @@ -72,18 +72,13 @@ f(using 32) //│ ║ l.63: fun f(using i: Int)(j: Int)(using k: Int)(l: Num): Int = i + j + k + l //│ ║ ^^^^^^ //│ ╙── Missing instance: Expected: Int; Available: ‹none available› -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1397) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1397': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. //│ = fun :e :re f(using 32)(8) //│ ╔══[COMPILATION ERROR] Cannot query instance of type Int for call: -//│ ║ l.84: f(using 32)(8) +//│ ║ l.79: f(using 32)(8) //│ ║ ^^^^^^^^^^^^^^ //│ ╟── Required by contextual parameter declaration: //│ ║ l.63: fun f(using i: Int)(j: Int)(using k: Int)(l: Num): Int = i + j + k + l diff --git a/hkmc2/shared/src/test/mlscript/invalml/InvalMLCodeGen.mls b/hkmc2/shared/src/test/mlscript/invalml/InvalMLCodeGen.mls index fc046c4b80..ad5acd3d0a 100644 --- a/hkmc2/shared/src/test/mlscript/invalml/InvalMLCodeGen.mls +++ b/hkmc2/shared/src/test/mlscript/invalml/InvalMLCodeGen.mls @@ -81,16 +81,6 @@ data class Foo(x: Int) //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "Foo", ["x"]]; //│ }); -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1387) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1387': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. -//│ FAILURE: Unexpected internal error -//│ FAILURE LOCATION: softAssert (Lowering.scala:1390) -//│ ╔══[INTERNAL ERROR] Compiler reached an unexpected state at 'Lowering.scala:1390': Cannot refine already-refined erased type Some(Primitive(Int)) to Primitive(Int) -//│ ╟── The compilation result may be incorrect. -//│ ╙── This is a compiler bug; please report it to the maintainers. //│ —————————————————| Output |————————————————————————————————————————————————————————————————————————— From 2838dadbcc9755e931fce4037672504f40a00222 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 4 Jun 2026 13:35:03 +0800 Subject: [PATCH 25/48] Update binaryen.js --- package-lock.json | 14 +++++++------- package.json | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/package-lock.json b/package-lock.json index 912717660c..e0ca4dbe9b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -7,7 +7,7 @@ "name": "mlscript", "dependencies": { "benchmark": "^2.1.4", - "binaryen": "^129.0.0", + "binaryen": "^130.0.0", "typescript": "^4.7.4" } }, @@ -22,9 +22,9 @@ } }, "node_modules/binaryen": { - "version": "129.0.0", - "resolved": "https://registry.npmjs.org/binaryen/-/binaryen-129.0.0.tgz", - "integrity": "sha512-NyF5J0SfRoLDthpPh36FGTycOEv3Eqnkq3+mP5Cqt6iD9BLGGJMEVuPzu81nhLy2MMpPKmRTM9VLZihfyRQv8A==", + "version": "130.0.0", + "resolved": "https://registry.npmjs.org/binaryen/-/binaryen-130.0.0.tgz", + "integrity": "sha512-XDrb+zql0RbFPKgj7MuH9zOc78R3Fa/P/VSGnnpdwYsvNZPWjcMYMdAkKCOQEL2A7yqgjSMTDRFp6gfSDW+/QQ==", "license": "Apache-2.0", "bin": { "wasm-as": "bin/wasm-as", @@ -75,9 +75,9 @@ } }, "binaryen": { - "version": "129.0.0", - "resolved": "https://registry.npmjs.org/binaryen/-/binaryen-129.0.0.tgz", - "integrity": "sha512-NyF5J0SfRoLDthpPh36FGTycOEv3Eqnkq3+mP5Cqt6iD9BLGGJMEVuPzu81nhLy2MMpPKmRTM9VLZihfyRQv8A==" + "version": "130.0.0", + "resolved": "https://registry.npmjs.org/binaryen/-/binaryen-130.0.0.tgz", + "integrity": "sha512-XDrb+zql0RbFPKgj7MuH9zOc78R3Fa/P/VSGnnpdwYsvNZPWjcMYMdAkKCOQEL2A7yqgjSMTDRFp6gfSDW+/QQ==" }, "lodash": { "version": "4.18.1", diff --git a/package.json b/package.json index ba3efb9c6a..56c503ef1f 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "mlscript", "dependencies": { "benchmark": "^2.1.4", - "binaryen": "^129.0.0", + "binaryen": "^130.0.0", "typescript": "^4.7.4" } } From 178a3391cca66e2ffaeeff5e48b195009bc44e6e Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 4 Jun 2026 15:33:29 +0800 Subject: [PATCH 26/48] WIP: Implement `FuncRef`, refine in `Lowering` --- .../src/main/scala/hkmc2/codegen/Block.scala | 4 ++ .../main/scala/hkmc2/codegen/Lowering.scala | 13 ++++++ .../main/scala/hkmc2/codegen/Printer.scala | 13 +++++- .../src/test/mlscript/codegen/ErasedType.mls | 45 ++++++++++++++++--- 4 files changed, 69 insertions(+), 6 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index ef0c06d1a6..ddd499740f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -872,6 +872,9 @@ enum ErasedType: */ case AnyRef(rsc: Bool, csym: ClassLikeSymbol | NoSymbol) + /** A reference to a function of a possibly-known shape. */ + case FuncRef(sig: Opt[Ls[Opt[ErasedType]] -> Opt[ErasedType]]) + /** An primitive type. */ case Primitive(prim: PrimitiveType) @@ -879,6 +882,7 @@ enum ErasedType: def sym(using Ctx, State): ClassLikeSymbol = this match case AnyRef(_, csym: ClassLikeSymbol) => csym case AnyRef(_, _: NoSymbol) => ctx.builtins.Object + case FuncRef(_) => ctx.builtins.Function case Primitive(prim) => prim.sym /** Trait representing a Block IR element that has an [[`ErasedType`]]. */ diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 29ac4cbfb5..fb542e3ef5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -249,6 +249,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): blockImpl(stats, res)))(using LoweringCtx.nestFunc) case syntax.Fun => val (paramLists, bodyBlock) = setupFunctionOrByNameDef(td.params, bod, S(td.sym.nme)) + refineFunDefnType(td.tsym, paramLists, td.sign) val cfgOverride = td.extraAnnotations.collectFirst: case Annot.Config(modify) => modify(config) Define(FunDefn(td.owner, td.sym, td.tsym, paramLists, bodyBlock)(cfgOverride, td.annotations), @@ -1208,6 +1209,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): .flatMap: td => td.body.map: bod => val (paramLists, bodyBlock) = setupFunctionDef(td.params, bod, S(td.sym.nme)) + refineFunDefnType(td.tsym, paramLists, td.sign) reportAnnotations(td, td.extraAnnotations) val cfgOverride = td.extraAnnotations.collectFirst: case Annot.Config(modify) => modify(config) @@ -1390,6 +1392,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case bms: BlockMemberSymbol => bms.tsym.foreach(_.refineErasedType(et)) case _ => + /** Populates a function definition symbol's erased type with a [[`ErasedType.FuncRef`]] + * derived from its (already-refined) parameter symbols and return-type annotation. + * + * The parameters of curried functions are flattened into a single list: this is lossy + * for the arrow shape but does not affect the rendered return type, the only consumer + * today. An unannotated return remains `N` (refined in a later phase). */ + private def refineFunDefnType(tsym: TermSymbol, paramLists: Ls[ParamList], sign: Opt[Term]): Unit = + val params = paramLists.flatMap(_.params).map(_.sym.erasedType) + val ret = sign.flatMap(eraseSign) + tsym.refineErasedType(ErasedType.FuncRef(S(params -> ret))) + def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) (using LoweringCtx): (List[ParamList], Block) = paramLists.foreach: pl => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index f9acc55ac5..40263eb8b7 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -33,11 +33,22 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): def print(et: ErasedType)(using Scope): Document = et match case ErasedType.AnyRef(rsc, csym: ClassLikeSymbol) => doc"${if rsc then "rsc " else ""}${print(csym)}" case ErasedType.AnyRef(rsc, _: NoSymbol) => doc"${if rsc then "rsc " else ""}Object" + case ErasedType.FuncRef(N) => doc"Function" + case ErasedType.FuncRef(S(params -> ret)) => + doc"(${params.map(_.fold(doc"?")(print)).mkDocument(sep = doc", ")}) => ${ret.fold(doc"?")(print)}" case ErasedType.Primitive(prim) => doc"${prim.toString}" def erasedTypeAnnot(x: HasErasedType)(using Scope): Document = if !summon[ShowCfg].showErasedTypes then doc"" else doc": ${x.erasedType.fold(doc"?")(print)}" + + /** Renders a function's return type, projected from the `FuncRef` carried by its + * definition symbol. Nothing is rendered when the symbol carries no `FuncRef`. */ + def returnTypeAnnot(dSym: TermSymbol)(using Scope): Document = + if !summon[ShowCfg].showErasedTypes then doc"" + else dSym.erasedType match + case S(ErasedType.FuncRef(sig)) => doc": ${sig.flatMap(_._2).fold(doc"?")(print)}" + case _ => doc"" def print(blk: Block)(using Scope): Document = blk match case Match(scrut, arms, dflt, rest) => @@ -149,7 +160,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): val docParams = printParamLists(paramss) val docBody = print(body) val docStaged = if fun.isStaged then doc"staged " else doc"" - doc"${docStaged}fun ${print(dSym)}${docParams} ${bracedbk(docBody)}" + doc"${docStaged}fun ${print(dSym)}${docParams}${returnTypeAnnot(dSym)} ${bracedbk(docBody)}" case ValDefn(tsym, sym, rhs) => doc"val ${print(tsym)}${erasedTypeAnnot(tsym)} = ${print(rhs)}" case cls @ ClsLikeDefn(own, isym, sym, ctorSym, k, paramsOpt, auxParams, parentSym, methods, diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 836a56ace5..778ede0e31 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -62,7 +62,7 @@ let made = mk(1) let sel = made.x //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let mk⁰, made⁰: ?, sel⁰: ?, selRes: ?; -//│ define mk⁰ as fun mk¹(a: ?) { +//│ define mk⁰ as fun mk¹(a: ?): ? { //│ return new Foo⁵(a) //│ }; //│ set made⁰ = mk¹(1); @@ -77,11 +77,43 @@ let sel = made.x //│ end -// Parameter annotations +// Parameter annotations. The return type is unannotated, so it stays `: ?` +// (residual inference is a later phase). :siret fun add(x: Int, y: Int) = x + y //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int) { return +⁰(x, y) }; end +//│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int): ? { return +⁰(x, y) }; end + + +// Return-type annotations are rendered at the `fun ...: T` definition site, +// for both primitive and class return types. +:siret +fun addUp(x: Int, y: Int): Int = x + y +fun makeFoo(n: Int): Foo = new Foo(n) +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let makeFoo⁰, addUp⁰; +//│ define addUp⁰ as fun addUp¹(x: Int, y: Int): Int { +//│ return +⁰(x, y) +//│ }; +//│ define makeFoo⁰ as fun makeFoo¹(n: Int): Foo⁵ { return new Foo⁵(n) }; +//│ end + + +// Method return-type annotations are rendered the same way as free functions. +:siret +class Calc() with + fun twice(x: Int): Int = x + x + fun whatever(a) = a +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let Calc⁰; +//│ define Calc⁰ as class Calc² { +//│ constructor Calc¹ +//│ method twice⁰ = fun twice¹(x: Int): Int { +//│ return +⁰(x, x) +//│ } +//│ method whatever⁰ = fun whatever¹(a: ?): ? { return a } +//│ }; +//│ end // Optimized IR with erased types: surfaces where a type relaxes to `: ?` @@ -91,6 +123,9 @@ fun add(x: Int, y: Int) = x + y fun id(a) = a id(new Foo(1)) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let id⁰, tmp: Foo⁵; define id⁰ as fun id¹(a: ?) { return a }; set tmp = new Foo⁵(1); return id¹(tmp) +//│ let id⁰, tmp: Foo⁵; +//│ define id⁰ as fun id¹(a: ?): ? { return a }; +//│ set tmp = new Foo⁵(1); +//│ return id¹(tmp) //│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— -//│ let id⁰; define id⁰ as fun id¹(a: ?) { return a }; return new Foo⁵(1) +//│ let id⁰; define id⁰ as fun id¹(a: ?): ? { return a }; return new Foo⁵(1) From 4203e00c5d7d23ef45496ef569ea00f90aadff4f Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 4 Jun 2026 15:53:09 +0800 Subject: [PATCH 27/48] WIP: Add refinement of return types and operator types --- .../src/main/scala/hkmc2/codegen/Block.scala | 2 ++ .../main/scala/hkmc2/codegen/Lowering.scala | 28 ++++++++++++++---- .../main/scala/hkmc2/semantics/Symbol.scala | 29 ++++++++++++++++++- .../src/test/mlscript/codegen/ErasedType.mls | 28 +++++++++++++++--- 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index ddd499740f..bbe0c219af 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -1007,6 +1007,8 @@ sealed abstract class Result extends AutoLocated, HasErasedType: case Value.MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => disamb.erasedType case Value.This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => clsOrMod.erasedType case Value.Lit(lit) => S(lit.erasedType) + case Call(Value.SimpleRef(bs: BuiltinSymbol), argss) => + bs.resultErasedType(argss.head.map(_.value.erasedType)) case _ => N /* mayRaiseEffects indicates whether this call may raise effect (algebraic effect), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index fb542e3ef5..ea28e4b882 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -249,7 +249,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): blockImpl(stats, res)))(using LoweringCtx.nestFunc) case syntax.Fun => val (paramLists, bodyBlock) = setupFunctionOrByNameDef(td.params, bod, S(td.sym.nme)) - refineFunDefnType(td.tsym, paramLists, td.sign) + refineFunDefnType(td.tsym, paramLists, td.sign, bodyBlock) val cfgOverride = td.extraAnnotations.collectFirst: case Annot.Config(modify) => modify(config) Define(FunDefn(td.owner, td.sym, td.tsym, paramLists, bodyBlock)(cfgOverride, td.annotations), @@ -1209,7 +1209,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): .flatMap: td => td.body.map: bod => val (paramLists, bodyBlock) = setupFunctionDef(td.params, bod, S(td.sym.nme)) - refineFunDefnType(td.tsym, paramLists, td.sign) + refineFunDefnType(td.tsym, paramLists, td.sign, bodyBlock) reportAnnotations(td, td.extraAnnotations) val cfgOverride = td.extraAnnotations.collectFirst: case Annot.Config(modify) => modify(config) @@ -1392,15 +1392,31 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case bms: BlockMemberSymbol => bms.tsym.foreach(_.refineErasedType(et)) case _ => + /** Infers a function's return type from its body's `return`s, but only when every + * `return` agrees on a single known type (a conservative equal-or-`N` join); a body + * with conflicting or unknown returns stays `N`. Nested function/lambda bodies are not + * descended into. */ + private def inferReturn(body: Block): Opt[ErasedType] = + var rets: Ls[Opt[ErasedType]] = Nil + new BlockTraverserShallow: + override def applyBlock(b: Block): Unit = b match + case Return(res) => rets ::= res.erasedType + case _ => super.applyBlock(b) + .applyBlock(body) + rets match + case head :: tail if tail.forall(_ == head) => head + case _ => N + /** Populates a function definition symbol's erased type with a [[`ErasedType.FuncRef`]] - * derived from its (already-refined) parameter symbols and return-type annotation. + * derived from its (already-refined) parameter symbols and return type. The return type + * comes from the explicit annotation when present, otherwise it is inferred from the body. * * The parameters of curried functions are flattened into a single list: this is lossy * for the arrow shape but does not affect the rendered return type, the only consumer - * today. An unannotated return remains `N` (refined in a later phase). */ - private def refineFunDefnType(tsym: TermSymbol, paramLists: Ls[ParamList], sign: Opt[Term]): Unit = + * today. */ + private def refineFunDefnType(tsym: TermSymbol, paramLists: Ls[ParamList], sign: Opt[Term], body: Block): Unit = val params = paramLists.flatMap(_.params).map(_.sym.erasedType) - val ret = sign.flatMap(eraseSign) + val ret = sign.flatMap(eraseSign) orElse inferReturn(body) tsym.refineErasedType(ErasedType.FuncRef(S(params -> ret))) def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index eb329eb8dc..47a33b2f15 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -10,7 +10,7 @@ import hkmc2.utils.* import Elaborator.State import Tree.Ident -import hkmc2.codegen.{ErasedType, HasErasedType, erasedType} +import hkmc2.codegen.{ErasedType, PrimitiveType, HasErasedType, erasedType} import hkmc2.utils.SymbolSubst import hkmc2.codegen.HasRefinableErasedType @@ -286,6 +286,33 @@ class BuiltinSymbol case _ => Bot semantics.flow.Producer.Typ(typ) + /** The result [[`ErasedType`]] of applying this builtin operator to operands of the + * given erased types, or `N` if this symbol is not a recognized operator. Context-free; + * surfaces the result-type knowledge already implicit in `BlockSimplifier.builtinEval`. */ + def resultErasedType(args: Ls[Opt[ErasedType]]): Opt[ErasedType] = + import ErasedType.Primitive + def isStr(t: Opt[ErasedType]) = t.contains(Primitive(PrimitiveType.Str)) + def isInt(t: Opt[ErasedType]) = t.contains(Primitive(PrimitiveType.Int)) + def isNum(t: Opt[ErasedType]) = isInt(t) || t.contains(Primitive(PrimitiveType.Num)) + nme match + case "==" | "!=" | "<" | "<=" | ">" | ">=" | "===" | "!==" | "&&" | "||" | "!" => + S(Primitive(PrimitiveType.Bool)) + case "typeof" => S(Primitive(PrimitiveType.Str)) + // * `+` is overloaded (numeric add / string concat): only commit when the operands + // * decide it, otherwise leave it unknown (an unknown operand could be a `Str`). + case "+" => + if args.exists(isStr) then S(Primitive(PrimitiveType.Str)) + else if args.forall(isInt) then S(Primitive(PrimitiveType.Int)) + else if args.forall(isNum) then S(Primitive(PrimitiveType.Num)) + else N + // * The remaining arithmetic operators are numeric-only, so the result is always a + // * number; it is an `Int` only when every operand is known to be an `Int`. + case "-" | "*" | "%" => + if args.forall(isInt) then S(Primitive(PrimitiveType.Int)) else S(Primitive(PrimitiveType.Num)) + case "/" => S(Primitive(PrimitiveType.Num)) + case "~" => S(Primitive(PrimitiveType.Int)) + case _ => N + /** This is the outside-facing symbol associated to a possibly-overloaded * definition living in a block – e.g., a module or class. diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 778ede0e31..0cdbf2e353 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -62,7 +62,7 @@ let made = mk(1) let sel = made.x //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let mk⁰, made⁰: ?, sel⁰: ?, selRes: ?; -//│ define mk⁰ as fun mk¹(a: ?): ? { +//│ define mk⁰ as fun mk¹(a: ?): Foo⁵ { //│ return new Foo⁵(a) //│ }; //│ set made⁰ = mk¹(1); @@ -77,12 +77,12 @@ let sel = made.x //│ end -// Parameter annotations. The return type is unannotated, so it stays `: ?` -// (residual inference is a later phase). +// Parameter annotations. The return type is unannotated but inferred from the +// body: `Int + Int` yields `Int`. :siret fun add(x: Int, y: Int) = x + y //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int): ? { return +⁰(x, y) }; end +//│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int): Int { return +⁰(x, y) }; end // Return-type annotations are rendered at the `fun ...: T` definition site, @@ -99,6 +99,26 @@ fun makeFoo(n: Int): Foo = new Foo(n) //│ end +// Inferred returns from the body terminal: a builtin comparison yields `Bool` +// (regardless of operand types), an instantiation yields the class, and an +// unconstrained identity return stays `: ?`. (Equality `==` lowers to a +// `Predef.equals` member call, not a builtin operator, so it stays `: ?`.) +:siret +fun lt(a, b) = a < b +fun mkFoo(a) = new Foo(a) +fun ident(a) = a +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let ident⁰, lt⁰, mkFoo⁰; +//│ define lt⁰ as fun lt¹(a: ?, b: ?): Bool { +//│ return <⁰(a, b) +//│ }; +//│ define mkFoo⁰ as fun mkFoo¹(a: ?): Foo⁵ { +//│ return new Foo⁵(a) +//│ }; +//│ define ident⁰ as fun ident¹(a: ?): ? { return a }; +//│ end + + // Method return-type annotations are rendered the same way as free functions. :siret class Calc() with From 202ef1dcf1945f0096d019210e2ed62e03b7eaef Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 11:48:47 +0800 Subject: [PATCH 28/48] Lowering: Correctly guard population of function types --- .../src/main/scala/hkmc2/codegen/Lowering.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index ea28e4b882..5aca313497 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -1411,13 +1411,19 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): * derived from its (already-refined) parameter symbols and return type. The return type * comes from the explicit annotation when present, otherwise it is inferred from the body. * + * This is a derived type that may be recomputed when a definition is lowered more than once + * (e.g. under `:lift`/`:effectHandlers`); since a later pass can infer a different return type, + * the result is recorded only the first time and left untouched afterwards (rather than going + * through the asserting [[`refineErasedType`]], which is meant for one-shot annotations). + * * The parameters of curried functions are flattened into a single list: this is lossy * for the arrow shape but does not affect the rendered return type, the only consumer * today. */ private def refineFunDefnType(tsym: TermSymbol, paramLists: Ls[ParamList], sign: Opt[Term], body: Block): Unit = - val params = paramLists.flatMap(_.params).map(_.sym.erasedType) - val ret = sign.flatMap(eraseSign) orElse inferReturn(body) - tsym.refineErasedType(ErasedType.FuncRef(S(params -> ret))) + if tsym.erasedType.isEmpty then + val params = paramLists.flatMap(_.params).map(_.sym.erasedType) + val ret = sign.flatMap(eraseSign) orElse inferReturn(body) + tsym.erasedType = S(ErasedType.FuncRef(S(params -> ret))) def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) (using LoweringCtx): (List[ParamList], Block) = From 4ddf2528599f0d86ca5e39ea75105093b0401128 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 11:53:01 +0800 Subject: [PATCH 29/48] codegen: Implement erased type relaxation When a variable is assigned twice using different-typed values. --- .../src/main/scala/hkmc2/codegen/Block.scala | 19 +++++ .../main/scala/hkmc2/codegen/Lowering.scala | 11 ++- .../src/test/mlscript/codegen/ErasedType.mls | 83 ++++++++++++++----- 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index bbe0c219af..31ea7d6027 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -904,6 +904,25 @@ trait HasRefinableErasedType extends HasErasedType: softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") if erasedType.isEmpty then erasedType = S(newType) + /** Tracks whether [[`observeErasedType`]] has seen at least one assignment to this symbol. */ + private var erasedTypeObserved: Bool = false + + /** + * Observes an assignment of `observed` to this (possibly reassignable) symbol, joining it with + * any previously known type. Unlike [[`refineErasedType`]], which expects a single authoritative + * write, this tolerates multiple assignments of differing types: the first observation of an + * otherwise-unknown symbol is recorded exactly (so a single assignment of an unknown type keeps + * `N`), and any later disagreement — including with a type already set by an annotation — widens to + * the top type [[`ErasedType.ObjectRef`]]. The join is monotone (`N` → known → top), so it never + * re-narrows; a bare `N` therefore means "never observed". + */ + def observeErasedType(observed: Opt[ErasedType]): Unit = + if !erasedTypeObserved && erasedType.isEmpty then + erasedType = observed + else if erasedType != observed then + erasedType = S(ErasedType.ObjectRef) + erasedTypeObserved = true + extension (lit: Literal) def erasedType: ErasedType = lit match case Tree.UnitLit(_) => ErasedType.Primitive(PrimitiveType.Unit) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 5aca313497..16f0cabb3b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -537,9 +537,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): sym.owner match case S(owner) => AssignField(owner.asThis, sym.id, rhs, rest)(S(sym)) case N => nope - case sym: LocalVarSymbol => Assign(sym, rhs, rest) + case sym: LocalVarSymbol => + observeLocalErasedType(sym, rhs) + Assign(sym, rhs, rest) case sym => nope + /** Joins the erased type of a local variable's RHS into its (refinable) symbol. Only `VarSymbol`s + * are refinable; compiler-generated `TempSymbol`s carry their type from creation, so they are skipped. */ + private def observeLocalErasedType(sym: LocalVarSymbol, rhs: Result): Unit = sym match + case sym: HasRefinableErasedType => sym.observeErasedType(rhs.erasedType) + case _ => + private def defineSymbol(sym: Symbol, rhs: Result, rest: Block)(using LoweringCtx): Block = sym match case sym: TermSymbol => @@ -547,6 +555,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case S(owner) => AssignField(owner.asThis, sym.id, rhs, rest)(S(sym)) case N => lastWords(s"tried to define top-level symbol ${sym.showDbg} in a local scope") case sym: LocalVarSymbol => + observeLocalErasedType(sym, rhs) Assign(sym, rhs, rest) case sym => lastWords(s"tried to define non-variable symbol ${sym.showDbg}") diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 0cdbf2e353..e864fdf8a7 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -4,14 +4,39 @@ //│ import Predef; end -// Literals: Int / Str / Num / Bool. +// Literals: Int / Str / Num / Bool. A local picks up its initializer's type. :siret let i = 1 let f = 1.5 let s = "str" let b = true //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let i⁰: ?, f⁰: ?, s⁰: ?, b⁰: ?; set i⁰ = 1; set f⁰ = 1.5; set s⁰ = "str"; set b⁰ = true; end +//│ let i⁰: Int, f⁰: Num, s⁰: Str, b⁰: Bool; +//│ set i⁰ = 1; +//│ set f⁰ = 1.5; +//│ set s⁰ = "str"; +//│ set b⁰ = true; +//│ end + + +// Local reassignment join: a local keeps its type while assignments agree, and +// widens to `Object` on a type conflict. A declared-then-assigned local (no +// initializer) is still typed from its assignment. +:siret +let consistent = 1 +set consistent = 2 +let conflicting = 1 +set conflicting = "two" +let declared +set declared = 3 +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let consistent⁰: Int, conflicting⁰: Object, declared⁰: Int; +//│ set consistent⁰ = 1; +//│ set consistent⁰ = 2; +//│ set conflicting⁰ = 1; +//│ set conflicting⁰ = "two"; +//│ set declared⁰ = 3; +//│ return runtime⁰.Unit⁰ // Tuple -> Array; record -> Object (the `NoSymbol` top case). @@ -19,7 +44,7 @@ let b = true let t = [i, s] let r = {x: i, y: s} //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let t⁰: ?, r⁰: ?, x⁰: ?, y⁰: ?; +//│ let t⁰: Array, r⁰: Object, x⁰: Int, y⁰: Str; //│ set t⁰ = [i⁰, s⁰]; //│ set x⁰ = i⁰; //│ set y⁰ = s⁰; @@ -32,12 +57,12 @@ let r = {x: i, y: s} class Foo(x: Int) let foo = new Foo(123) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let Foo⁰, foo⁰: ?; -//│ define Foo⁰ as class Foo² { +//│ let Foo⁰, foo⁰: Foo¹; +//│ define Foo⁰ as class Foo¹ { //│ private val x¹: Int; -//│ constructor Foo¹(x: Int) { set Foo².this.x¹ = x; end } +//│ constructor Foo²(x: Int) { set Foo¹.this.x¹ = x; end } //│ }; -//│ set foo⁰ = new Foo²(123); +//│ set foo⁰ = new Foo¹(123); //│ end @@ -46,12 +71,12 @@ let foo = new Foo(123) class Foo(val x: Str) let foo = new Foo(123) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let Foo³, foo¹: ?; -//│ define Foo³ as class Foo⁵ { +//│ let Foo³, foo¹: Foo⁴; +//│ define Foo³ as class Foo⁴ { //│ val x²: Str; -//│ constructor Foo⁴(x: Str) { define x² as val x³: Str = x; end } +//│ constructor Foo⁵(x: Str) { define x² as val x³: Str = x; end } //│ }; -//│ set foo¹ = new Foo⁵(123); +//│ set foo¹ = new Foo⁴(123); //│ end @@ -62,8 +87,8 @@ let made = mk(1) let sel = made.x //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— //│ let mk⁰, made⁰: ?, sel⁰: ?, selRes: ?; -//│ define mk⁰ as fun mk¹(a: ?): Foo⁵ { -//│ return new Foo⁵(a) +//│ define mk⁰ as fun mk¹(a: ?): Foo⁴ { +//│ return new Foo⁴(a) //│ }; //│ set made⁰ = mk¹(1); //│ set selRes = made⁰.x﹖; @@ -85,6 +110,26 @@ fun add(x: Int, y: Int) = x + y //│ let add⁰; define add⁰ as fun add¹(x: Int, y: Int): Int { return +⁰(x, y) }; end +// A reassigned parameter joins with its annotation: a conflicting `set` widens +// the annotated type (and the inferred return) to `Object`; a consistent `set` +// keeps it. +:siret +fun widen(x: Int) = + set x = "s" + x +fun keep(x: Int) = + set x = 5 + x +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let widen⁰, keep⁰; +//│ define widen⁰ as fun widen¹(x: Object): Object { +//│ set x = "s"; +//│ return x +//│ }; +//│ define keep⁰ as fun keep¹(x: Int): Int { set x = 5; return x }; +//│ end + + // Return-type annotations are rendered at the `fun ...: T` definition site, // for both primitive and class return types. :siret @@ -95,7 +140,7 @@ fun makeFoo(n: Int): Foo = new Foo(n) //│ define addUp⁰ as fun addUp¹(x: Int, y: Int): Int { //│ return +⁰(x, y) //│ }; -//│ define makeFoo⁰ as fun makeFoo¹(n: Int): Foo⁵ { return new Foo⁵(n) }; +//│ define makeFoo⁰ as fun makeFoo¹(n: Int): Foo⁴ { return new Foo⁴(n) }; //│ end @@ -112,8 +157,8 @@ fun ident(a) = a //│ define lt⁰ as fun lt¹(a: ?, b: ?): Bool { //│ return <⁰(a, b) //│ }; -//│ define mkFoo⁰ as fun mkFoo¹(a: ?): Foo⁵ { -//│ return new Foo⁵(a) +//│ define mkFoo⁰ as fun mkFoo¹(a: ?): Foo⁴ { +//│ return new Foo⁴(a) //│ }; //│ define ident⁰ as fun ident¹(a: ?): ? { return a }; //│ end @@ -143,9 +188,9 @@ class Calc() with fun id(a) = a id(new Foo(1)) //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let id⁰, tmp: Foo⁵; +//│ let id⁰, tmp: Foo⁴; //│ define id⁰ as fun id¹(a: ?): ? { return a }; -//│ set tmp = new Foo⁵(1); +//│ set tmp = new Foo⁴(1); //│ return id¹(tmp) //│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— -//│ let id⁰; define id⁰ as fun id¹(a: ?): ? { return a }; return new Foo⁵(1) +//│ let id⁰; define id⁰ as fun id¹(a: ?): ? { return a }; return new Foo⁴(1) From 4f63bafde6b3fd02c1a0ba2d3e32fa2742bf128d Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 12:08:33 +0800 Subject: [PATCH 30/48] codegen: Propagate call types --- hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala | 6 ++++++ hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls | 5 +++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 31ea7d6027..22911116a3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -1028,6 +1028,12 @@ sealed abstract class Result extends AutoLocated, HasErasedType: case Value.Lit(lit) => S(lit.erasedType) case Call(Value.SimpleRef(bs: BuiltinSymbol), argss) => bs.resultErasedType(argss.head.map(_.value.erasedType)) + case Call(fun, _) => fun.targetSymbol match + // * A call's result is the callee's return type, recovered from its `FuncRef` (Phase F.2). + case S(ts: TermSymbol) => ts.erasedType match + case S(ErasedType.FuncRef(sig)) => sig.flatMap(_._2) + case _ => N + case _ => N case _ => N /* mayRaiseEffects indicates whether this call may raise effect (algebraic effect), diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index e864fdf8a7..75b8f4b193 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -80,13 +80,14 @@ let foo = new Foo(123) //│ end -// Call return / field selection +// Call return is inferred from the callee's return type; field selection stays +// `: ?` (resolved in a later phase). :siret fun mk(a) = new Foo(a) let made = mk(1) let sel = made.x //│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— -//│ let mk⁰, made⁰: ?, sel⁰: ?, selRes: ?; +//│ let mk⁰, made⁰: Foo⁴, sel⁰: ?, selRes: ?; //│ define mk⁰ as fun mk¹(a: ?): Foo⁴ { //│ return new Foo⁴(a) //│ }; From f3bca96daa911e703b016ef48a649f899ffcb6f9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 12:19:36 +0800 Subject: [PATCH 31/48] codegen: Propagate selections --- .../src/main/scala/hkmc2/codegen/Block.scala | 6 ++++++ .../src/test/mlscript/codegen/ErasedType.mls | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 22911116a3..de74ba4f3c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -1034,6 +1034,12 @@ sealed abstract class Result extends AutoLocated, HasErasedType: case S(ErasedType.FuncRef(sig)) => sig.flatMap(_._2) case _ => N case _ => N + // * A resolved selection has the type of the member it refers to (e.g. `this.field`); an + // * unresolved selection (dynamic field access) stays unknown. + case sel @ Select(_, _) => sel.symbol match + case S(ts: TermSymbol) => ts.erasedType + case S(d: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => d.erasedType + case _ => N case _ => N /* mayRaiseEffects indicates whether this call may raise effect (algebraic effect), diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 75b8f4b193..1b784fd2d7 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -182,6 +182,24 @@ class Calc() with //│ end +// A selection resolved to a member (`this.v`) carries that member's type, so a +// field getter's return is inferred. +:siret +class Box(val v: Int) with + fun get() = v +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let Box⁰; +//│ define Box⁰ as class Box² { +//│ val v⁰: Int; +//│ constructor Box¹(v: Int) { +//│ define v⁰ as val v¹: Int = v; +//│ end +//│ } +//│ method get⁰ = fun get¹(): Int { return Box².this.v¹ } +//│ }; +//│ end + + // Optimized IR with erased types: surfaces where a type relaxes to `: ?` // through a transform (the relaxation gap `:soir` makes visible). :soir From 72a334b1aa2a189bc654d62d02c80d9529448105 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 12:30:46 +0800 Subject: [PATCH 32/48] codegen: Infer rest params --- .../shared/src/main/scala/hkmc2/codegen/Lowering.scala | 10 +++++++++- hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls | 8 ++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 16f0cabb3b..11bcda437f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -291,7 +291,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ) case _ => _defn reportAnnotations(defn, defn.extraAnnotations) - (defn.paramsOpt.iterator ++ defn.auxParams.iterator).flatMap(_.params).foreach(refineClassParam) + (defn.paramsOpt.iterator ++ defn.auxParams.iterator).foreach: pl => + pl.params.foreach(refineClassParam) + pl.restParam.foreach(refineRestParam) val bufferableAnnots = defn.annotations.flatMap: case Annot.Trm(trm: SynthSel) => if trm.sym.contains(ctx.builtins.annotations.buffered) then @@ -1401,6 +1403,11 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case bms: BlockMemberSymbol => bms.tsym.foreach(_.refineErasedType(et)) case _ => + /** A rest parameter always binds an array of the collected arguments, regardless of any + * element annotation, so its erased type is always `Array`. */ + private def refineRestParam(p: Param): Unit = + p.sym.refineErasedType(ErasedType.Primitive(PrimitiveType.Array)) + /** Infers a function's return type from its body's `return`s, but only when every * `return` agrees on a single known type (a conservative equal-or-`N` join); a body * with conflicting or unknown returns stays `N`. Nested function/lambda bodies are not @@ -1439,6 +1446,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): paramLists.foreach: pl => pl.params.foreach: p => p.sign.flatMap(eraseSign).foreach(p.sym.refineErasedType) + pl.restParam.foreach(refineRestParam) val scopedBody = inScopedBlock(returnedTerm(bodyTerm)) (paramLists, scopedBody) diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 1b784fd2d7..60d7b7773f 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -145,6 +145,14 @@ fun makeFoo(n: Int): Foo = new Foo(n) //│ end +// A rest parameter binds an array of the collected arguments, so it is typed +// `Array` regardless of any element annotation. +:siret +fun rest(a, ...xs) = xs +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let rest⁰; define rest⁰ as fun rest¹(a: ?, ...xs: Array): Array { return xs }; end + + // Inferred returns from the body terminal: a builtin comparison yields `Bool` // (regardless of operand types), an instantiation yields the class, and an // unconstrained identity return stays `: ?`. (Equality `==` lowers to a From 5e49f52ba68a9a6ced2b4e2d9c81f9c5a3a614a1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 12:48:12 +0800 Subject: [PATCH 33/48] codegen: Infer type of lifter capture symbols --- .../src/main/scala/hkmc2/codegen/Lifter.scala | 13 ++++++-- .../src/test/mlscript/codegen/ErasedType.mls | 33 +++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index b959e07cd6..ba992af5d6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -678,9 +678,14 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): protected final def addExtraSyms(b: Block, captureSym: LocalVarSymbol, objSyms: Iterable[ScopedSymbol]): Block = if hasCapture then + val inst = instantiateCapture + // * The capture symbol holds an instance of the capture class, so it takes that type. + captureSym match + case s: HasRefinableErasedType => s.erasedType = inst.erasedType + case _ => Scoped( objSyms.toSet + captureSym, - Assign(captureSym, instantiateCapture, b) + Assign(captureSym, inst, b) ) else Scoped(objSyms.toSet, b) @@ -899,7 +904,11 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val (liftedMtds, extras) = mtds.map(liftNestedScopes).unzip(using l => (l.liftedDefn, l.extraDefns)) LifterResult(liftedMtds, extras.flatten) protected final def initCaptureField(b: Block): Block = - if hasCapture then AssignField(sym.asThis, captureSym.id, instantiateCapture, b)(S(captureSym)) + if hasCapture then + val inst = instantiateCapture + // * The capture field holds an instance of the capture class, so it takes that type. + captureSym.erasedType = inst.erasedType + AssignField(sym.asThis, captureSym.id, inst, b)(S(captureSym)) else b // some helpers diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 60d7b7773f..185b13c80a 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -221,3 +221,36 @@ id(new Foo(1)) //│ return id¹(tmp) //│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— //│ let id⁰; define id⁰ as fun id¹(a: ?): ? { return a }; return new Foo⁴(1) + + +// Closure capture (`:lift`): the symbol holding the capture object is typed as +// the generated capture class. +:siret +:lift +fun counter() = + let n = 0 + fun bump() = set n = n + 1 + bump() + n +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let bump⁰, counter⁰, Capture$scope0⁰; +//│ define Capture$scope0⁰ as class Capture$scope0² { +//│ constructor Capture$scope0¹(n$0: ?) { +//│ define n$0⁰ as val n$0¹: ? = n$0; +//│ end +//│ } +//│ }; +//│ define bump⁰ as fun bump¹(scope0$cap: ?): Unit⁰ { +//│ let tmp: ?; +//│ set tmp = +⁰(scope0$cap.n$0¹, 1); +//│ set scope0$cap.n$0¹ = tmp; +//│ return runtime⁰.Unit⁰ +//│ }; +//│ define counter⁰ as fun counter¹(): Object { +//│ let n: Object, scope0$cap: Capture$scope0²; +//│ set scope0$cap = new mut Capture$scope0²(n); +//│ set scope0$cap.n$0¹ = 0; +//│ do bump¹(scope0$cap); +//│ return scope0$cap.n$0¹ +//│ }; +//│ end From ca6dd9db3f03b801ba0d82994e96c3046e03b335 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 13:02:38 +0800 Subject: [PATCH 34/48] Rename refine -> populate --- .../src/main/scala/hkmc2/codegen/Block.scala | 6 ++-- .../src/main/scala/hkmc2/codegen/Lifter.scala | 2 +- .../main/scala/hkmc2/codegen/Lowering.scala | 30 +++++++++---------- .../main/scala/hkmc2/semantics/Symbol.scala | 6 ++-- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index de74ba4f3c..2194756d4b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -894,12 +894,12 @@ trait HasErasedType: def erasedType_! : ErasedType = erasedType.getOrElse(ErasedType.ObjectRef) /** A [[`HasErasedType`]] that can have its erased type refined post-construction. */ -trait HasRefinableErasedType extends HasErasedType: +trait HasMutableErasedType extends HasErasedType: // Implementation Note: Provided for overriding classes to implement `erasedType` directly as an `override var` def erasedType_=(newType: Opt[ErasedType]): Unit /** Refines the erased type, or raises a soft assertion if the type was already previously refined. */ - def refineErasedType(newType: ErasedType)(using Line, FileName, Raise): Unit = + def populateErasedType(newType: ErasedType)(using Line, FileName, Raise): Unit = // TODO(Derppening): Restore `erasedType.isEmpty` once JS sanitization is converted into a pass softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") if erasedType.isEmpty then erasedType = S(newType) @@ -909,7 +909,7 @@ trait HasRefinableErasedType extends HasErasedType: /** * Observes an assignment of `observed` to this (possibly reassignable) symbol, joining it with - * any previously known type. Unlike [[`refineErasedType`]], which expects a single authoritative + * any previously known type. Unlike [[`populateErasedType`]], which expects a single authoritative * write, this tolerates multiple assignments of differing types: the first observation of an * otherwise-unknown symbol is recorded exactly (so a single assignment of an unknown type keeps * `N`), and any later disagreement — including with a type already set by an annotation — widens to diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index ba992af5d6..074ed6238f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -681,7 +681,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val inst = instantiateCapture // * The capture symbol holds an instance of the capture class, so it takes that type. captureSym match - case s: HasRefinableErasedType => s.erasedType = inst.erasedType + case s: HasMutableErasedType => s.erasedType = inst.erasedType case _ => Scoped( objSyms.toSet + captureSym, diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 11bcda437f..151979091a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -249,7 +249,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): blockImpl(stats, res)))(using LoweringCtx.nestFunc) case syntax.Fun => val (paramLists, bodyBlock) = setupFunctionOrByNameDef(td.params, bod, S(td.sym.nme)) - refineFunDefnType(td.tsym, paramLists, td.sign, bodyBlock) + populateFunDefnType(td.tsym, paramLists, td.sign, bodyBlock) val cfgOverride = td.extraAnnotations.collectFirst: case Annot.Config(modify) => modify(config) Define(FunDefn(td.owner, td.sym, td.tsym, paramLists, bodyBlock)(cfgOverride, td.annotations), @@ -292,8 +292,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case _ => _defn reportAnnotations(defn, defn.extraAnnotations) (defn.paramsOpt.iterator ++ defn.auxParams.iterator).foreach: pl => - pl.params.foreach(refineClassParam) - pl.restParam.foreach(refineRestParam) + pl.params.foreach(populateClassParam) + pl.restParam.foreach(populateRestParam) val bufferableAnnots = defn.annotations.flatMap: case Annot.Trm(trm: SynthSel) => if trm.sym.contains(ctx.builtins.annotations.buffered) then @@ -547,7 +547,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): /** Joins the erased type of a local variable's RHS into its (refinable) symbol. Only `VarSymbol`s * are refinable; compiler-generated `TempSymbol`s carry their type from creation, so they are skipped. */ private def observeLocalErasedType(sym: LocalVarSymbol, rhs: Result): Unit = sym match - case sym: HasRefinableErasedType => sym.observeErasedType(rhs.erasedType) + case sym: HasMutableErasedType => sym.observeErasedType(rhs.erasedType) case _ => private def defineSymbol(sym: Symbol, rhs: Result, rest: Block)(using LoweringCtx): Block = @@ -1220,7 +1220,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): .flatMap: td => td.body.map: bod => val (paramLists, bodyBlock) = setupFunctionDef(td.params, bod, S(td.sym.nme)) - refineFunDefnType(td.tsym, paramLists, td.sign, bodyBlock) + populateFunDefnType(td.tsym, paramLists, td.sign, bodyBlock) reportAnnotations(td, td.extraAnnotations) val cfgOverride = td.extraAnnotations.collectFirst: case Annot.Config(modify) => modify(config) @@ -1395,18 +1395,18 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): private def eraseSign(sign: Term): Opt[ErasedType] = sign.symbol.flatMap(_.asClsOrMod).map(sym => ErasedType.fromClsLikeSymbol(sym, rsc = false)) - private def refineClassParam(p: Param): Unit = + private def populateClassParam(p: Param): Unit = p.sign.flatMap(eraseSign).foreach: et => - p.sym.refineErasedType(et) + p.sym.populateErasedType(et) p.fldSym.foreach: - case fld: TermSymbol => fld.refineErasedType(et) - case bms: BlockMemberSymbol => bms.tsym.foreach(_.refineErasedType(et)) + case fld: TermSymbol => fld.populateErasedType(et) + case bms: BlockMemberSymbol => bms.tsym.foreach(_.populateErasedType(et)) case _ => /** A rest parameter always binds an array of the collected arguments, regardless of any * element annotation, so its erased type is always `Array`. */ - private def refineRestParam(p: Param): Unit = - p.sym.refineErasedType(ErasedType.Primitive(PrimitiveType.Array)) + private def populateRestParam(p: Param): Unit = + p.sym.populateErasedType(ErasedType.Primitive(PrimitiveType.Array)) /** Infers a function's return type from its body's `return`s, but only when every * `return` agrees on a single known type (a conservative equal-or-`N` join); a body @@ -1430,12 +1430,12 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): * This is a derived type that may be recomputed when a definition is lowered more than once * (e.g. under `:lift`/`:effectHandlers`); since a later pass can infer a different return type, * the result is recorded only the first time and left untouched afterwards (rather than going - * through the asserting [[`refineErasedType`]], which is meant for one-shot annotations). + * through the asserting [[`populateErasedType`]], which is meant for one-shot annotations). * * The parameters of curried functions are flattened into a single list: this is lossy * for the arrow shape but does not affect the rendered return type, the only consumer * today. */ - private def refineFunDefnType(tsym: TermSymbol, paramLists: Ls[ParamList], sign: Opt[Term], body: Block): Unit = + private def populateFunDefnType(tsym: TermSymbol, paramLists: Ls[ParamList], sign: Opt[Term], body: Block): Unit = if tsym.erasedType.isEmpty then val params = paramLists.flatMap(_.params).map(_.sym.erasedType) val ret = sign.flatMap(eraseSign) orElse inferReturn(body) @@ -1445,8 +1445,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): (using LoweringCtx): (List[ParamList], Block) = paramLists.foreach: pl => pl.params.foreach: p => - p.sign.flatMap(eraseSign).foreach(p.sym.refineErasedType) - pl.restParam.foreach(refineRestParam) + p.sign.flatMap(eraseSign).foreach(p.sym.populateErasedType) + pl.restParam.foreach(populateRestParam) val scopedBody = inScopedBlock(returnedTerm(bodyTerm)) (paramLists, scopedBody) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 47a33b2f15..74609f2d9f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -12,7 +12,7 @@ import Elaborator.State import Tree.Ident import hkmc2.codegen.{ErasedType, PrimitiveType, HasErasedType, erasedType} import hkmc2.utils.SymbolSubst -import hkmc2.codegen.HasRefinableErasedType +import hkmc2.codegen.HasMutableErasedType sealed abstract class MaybeSymbol: @@ -251,7 +251,7 @@ class InstSymbol(val origin: Symbol)(using State) extends LocalSymbol: class VarSymbol(val id: Ident, override var erasedType: Opt[ErasedType])(using State) extends LocalVarSymbol(id.name) - with HasRefinableErasedType + with HasMutableErasedType with NamedSymbol: val name: Str = id.name override def toLoc: Opt[Loc] = id.toLoc @@ -371,7 +371,7 @@ sealed abstract class MemberSymbol(using State) extends Symbol: class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident, override var erasedType: Opt[ErasedType])(using State) extends MemberSymbol with DefinitionSymbol[TermDefinition] - with HasRefinableErasedType + with HasMutableErasedType with NamedSymbol: def nme: Str = id.name def name: Str = nme From 4427c921e9ea7de82a0bc632da4873e283e4bf6c Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 16:17:26 +0800 Subject: [PATCH 35/48] Rewrite Scaladoc, add TODOs --- .../src/main/scala/hkmc2/codegen/Block.scala | 30 ++++++------ .../main/scala/hkmc2/codegen/Lowering.scala | 49 +++++++++---------- .../main/scala/hkmc2/codegen/Printer.scala | 6 +-- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 2 +- .../main/scala/hkmc2/semantics/Symbol.scala | 8 +-- 5 files changed, 46 insertions(+), 49 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index d9a7aabd63..f436f227b5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -885,10 +885,10 @@ enum ErasedType: /** The symbol for this erased type. */ def sym(using Ctx, State): ClassLikeSymbol = this match - case AnyRef(_, csym: ClassLikeSymbol) => csym - case AnyRef(_, _: NoSymbol) => ctx.builtins.Object - case FuncRef(_) => ctx.builtins.Function - case Primitive(prim) => prim.sym + case AnyRef(_, csym: ClassLikeSymbol) => csym + case AnyRef(_, _: NoSymbol) => ctx.builtins.Object + case FuncRef(_) => ctx.builtins.Function + case Primitive(prim) => prim.sym /** Trait representing a Block IR element that has an [[`ErasedType`]]. */ trait HasErasedType: @@ -903,25 +903,25 @@ trait HasMutableErasedType extends HasErasedType: // Implementation Note: Provided for overriding classes to implement `erasedType` directly as an `override var` def erasedType_=(newType: Opt[ErasedType]): Unit - /** Refines the erased type, or raises a soft assertion if the type was already previously refined. */ + /** Populates the erased type, or raises a soft assertion if the type was already populated. */ def populateErasedType(newType: ErasedType)(using Line, FileName, Raise): Unit = - // TODO(Derppening): Restore `erasedType.isEmpty` once JS sanitization is converted into a pass + // TODO(Derppening): Restore `erasedType.isEmpty` once JS sanitization is converted into a pass, allowing us to + // only lower each program once softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") if erasedType.isEmpty then erasedType = S(newType) - /** Tracks whether [[`observeErasedType`]] has seen at least one assignment to this symbol. */ + /** Tracks whether [[`observeErasedTypeAssign`]] has seen at least one assignment to this symbol. */ private var erasedTypeObserved: Bool = false /** - * Observes an assignment of `observed` to this (possibly reassignable) symbol, joining it with - * any previously known type. Unlike [[`populateErasedType`]], which expects a single authoritative - * write, this tolerates multiple assignments of differing types: the first observation of an - * otherwise-unknown symbol is recorded exactly (so a single assignment of an unknown type keeps - * `N`), and any later disagreement — including with a type already set by an annotation — widens to - * the top type [[`ErasedType.ObjectRef`]]. The join is monotone (`N` → known → top), so it never - * re-narrows; a bare `N` therefore means "never observed". + * Observes an assignment of a value with type `observed` to this symbol. + * + * Some symbols (e.g. [[`LocalVarSymbol`]]) can be assigned to multiple times with different types of values. This + * method tracks multiple assignments by coercing the type of the symbol to the top type if a subsequent assignemnt + * does not share the same type as the previously populated type. */ - def observeErasedType(observed: Opt[ErasedType]): Unit = + // TODO(Derppening): We should probably coerce to the common supertype rather than directly to the top type + def observeErasedTypeAssign(observed: Opt[ErasedType]): Unit = if !erasedTypeObserved && erasedType.isEmpty then erasedType = observed else if erasedType != observed then diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 40a5176415..96dd97aa29 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -287,8 +287,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case _ => _defn reportAnnotations(defn, defn.extraAnnotations) (defn.paramsOpt.iterator ++ defn.auxParams.iterator).foreach: pl => - pl.params.foreach(populateClassParam) - pl.restParam.foreach(populateRestParam) + pl.params.foreach(populateClassParamErasedType) + pl.restParam.foreach(populateRestParamErasedType) val bufferableAnnots = defn.annotations.flatMap: case Annot.Trm(trm: SynthSel) => if trm.sym.contains(ctx.builtins.annotations.buffered) then @@ -562,10 +562,12 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Assign(sym, rhs, rest) case sym => nope - /** Joins the erased type of a local variable's RHS into its (refinable) symbol. Only `VarSymbol`s - * are refinable; compiler-generated `TempSymbol`s carry their type from creation, so they are skipped. */ + /** Observes an assignment of `rhs` to `sym`, populating or updating its erased type where applicable. + * + * See [[`HasMutableErasedType.observeErasedTypeAssign`]]. + */ private def observeLocalErasedType(sym: LocalVarSymbol, rhs: Result): Unit = sym match - case sym: HasMutableErasedType => sym.observeErasedType(rhs.erasedType) + case sym: HasMutableErasedType => sym.observeErasedTypeAssign(rhs.erasedType) case _ => private def defineSymbol(sym: Symbol, rhs: Result, rest: Block)(using LoweringCtx): Block = @@ -1411,8 +1413,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): /** Erases a type-annotated term to an [[`ErasedType`]]. */ private def eraseSign(sign: Term): Opt[ErasedType] = sign.symbol.flatMap(_.asClsOrMod).map(sym => ErasedType.fromClsLikeSymbol(sym, rsc = false)) - - private def populateClassParam(p: Param): Unit = + + /** Populates the [[`ErasedType`]] of a class parameter. */ + private def populateClassParamErasedType(p: Param): Unit = p.sign.flatMap(eraseSign).foreach: et => p.sym.populateErasedType(et) p.fldSym.foreach: @@ -1420,15 +1423,13 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case bms: BlockMemberSymbol => bms.tsym.foreach(_.populateErasedType(et)) case _ => - /** A rest parameter always binds an array of the collected arguments, regardless of any - * element annotation, so its erased type is always `Array`. */ - private def populateRestParam(p: Param): Unit = + /** Populates the [[`ErasedType`]] of the `rest` parameter. */ + private def populateRestParamErasedType(p: Param): Unit = p.sym.populateErasedType(ErasedType.Primitive(PrimitiveType.Array)) - /** Infers a function's return type from its body's `return`s, but only when every - * `return` agrees on a single known type (a conservative equal-or-`N` join); a body - * with conflicting or unknown returns stays `N`. Nested function/lambda bodies are not - * descended into. */ + /** Infers the [[`ErasedType`]] of a function's return type, by inspecting the erased type of all return values. */ + // TODO(Derppening): This should return `N` only if any return value is `N` - Conflicting known return types should + // be joined to `AnyRef` (or common ancestor) private def inferReturn(body: Block): Opt[ErasedType] = var rets: Ls[Opt[ErasedType]] = Nil new BlockTraverserShallow: @@ -1440,18 +1441,12 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case head :: tail if tail.forall(_ == head) => head case _ => N - /** Populates a function definition symbol's erased type with a [[`ErasedType.FuncRef`]] - * derived from its (already-refined) parameter symbols and return type. The return type - * comes from the explicit annotation when present, otherwise it is inferred from the body. - * - * This is a derived type that may be recomputed when a definition is lowered more than once - * (e.g. under `:lift`/`:effectHandlers`); since a later pass can infer a different return type, - * the result is recorded only the first time and left untouched afterwards (rather than going - * through the asserting [[`populateErasedType`]], which is meant for one-shot annotations). - * - * The parameters of curried functions are flattened into a single list: this is lossy - * for the arrow shape but does not affect the rendered return type, the only consumer - * today. */ + /** Populates a function definition symbol's erased type with a [[`ErasedType.FuncRef`]] derived from its + * (already-populated) parameter symbols and return type. + * + * The return type comes from the explicit annotation when present, otherwise it is inferred from the body. + */ + // TODO(Derppening): Parameters of curried functions are currently flattened - Should we preserve the curried shape? private def populateFunDefnType(tsym: TermSymbol, paramLists: Ls[ParamList], sign: Opt[Term], body: Block): Unit = if tsym.erasedType.isEmpty then val params = paramLists.flatMap(_.params).map(_.sym.erasedType) @@ -1463,7 +1458,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): paramLists.foreach: pl => pl.params.foreach: p => p.sign.flatMap(eraseSign).foreach(p.sym.populateErasedType) - pl.restParam.foreach(populateRestParam) + pl.restParam.foreach(populateRestParamErasedType) val scopedBody = inScopedBlock(returnedTerm(bodyTerm)) (paramLists, scopedBody) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index e97c2686fb..129c760d3a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -37,13 +37,13 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case ErasedType.FuncRef(S(params -> ret)) => doc"(${params.map(_.fold(doc"?")(print)).mkDocument(sep = doc", ")}) => ${ret.fold(doc"?")(print)}" case ErasedType.Primitive(prim) => doc"${prim.toString}" - + + /** Renders the type annotation for a symbol with an [[`ErasedType`]]. */ def erasedTypeAnnot(x: HasErasedType)(using Scope): Document = if !summon[ShowCfg].showErasedTypes then doc"" else doc": ${x.erasedType.fold(doc"?")(print)}" - /** Renders a function's return type, projected from the `FuncRef` carried by its - * definition symbol. Nothing is rendered when the symbol carries no `FuncRef`. */ + /** Renders a function's return type, projected from the `FuncRef` carried by its definition symbol. */ def returnTypeAnnot(dSym: TermSymbol)(using Scope): Document = if !summon[ShowCfg].showErasedTypes then doc"" else dSym.erasedType match diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 0e5281b554..1cf70befa3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -40,7 +40,7 @@ extension (et: ErasedType) extension (sym: ValueSymbol) /** The Wasm reference type a *local* slot for `sym` should be declared with. * - * Use [[FunctionCtx.slowRefType]] for parameter slots, which handles `anyref` widening due to virtual dispatch + * Use [[`FunctionCtx.slotRefType`]] for parameter slots, which handles `anyref` widening due to virtual dispatch * calling conventions. */ private[text] def localRefType(using Ctx, State): RefType = diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 74609f2d9f..0e5f7e6c96 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -286,9 +286,11 @@ class BuiltinSymbol case _ => Bot semantics.flow.Producer.Typ(typ) - /** The result [[`ErasedType`]] of applying this builtin operator to operands of the - * given erased types, or `N` if this symbol is not a recognized operator. Context-free; - * surfaces the result-type knowledge already implicit in `BlockSimplifier.builtinEval`. */ + /** The result [[`ErasedType`]] of applying this builtin operator to operands of the given erased types. + * + * Returns `N` if the operator is not recognized or the arguments to the operator is not sufficient to determine the + * result type. + */ def resultErasedType(args: Ls[Opt[ErasedType]]): Opt[ErasedType] = import ErasedType.Primitive def isStr(t: Opt[ErasedType]) = t.contains(Primitive(PrimitiveType.Str)) From 68b8c9efd092f285c1082549cf17a56e229a8f1b Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 16:26:31 +0800 Subject: [PATCH 36/48] codegen: Split `HasMutableErasedType` --- hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala | 7 +++++-- hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala | 2 +- hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala | 4 ++-- hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala | 7 +++---- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index f436f227b5..32d677abf8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -898,8 +898,8 @@ trait HasErasedType: /** Similar to `erasedType`, but coerces to the top type if the specific erased type is not known. */ def erasedType_! : ErasedType = erasedType.getOrElse(ErasedType.ObjectRef) -/** A [[`HasErasedType`]] that can have its erased type refined post-construction. */ -trait HasMutableErasedType extends HasErasedType: +/** A [[`HasErasedType`]] whose erased type can be populated exactly once post-construction. */ +trait HasOnceMutableErasedType extends HasErasedType: // Implementation Note: Provided for overriding classes to implement `erasedType` directly as an `override var` def erasedType_=(newType: Opt[ErasedType]): Unit @@ -910,6 +910,9 @@ trait HasMutableErasedType extends HasErasedType: softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") if erasedType.isEmpty then erasedType = S(newType) +/** A [[`HasOnceMutableErasedType`]] whose erased type can additionally be assigned multiple times, joining the + * observed types into the top type on disagreement. */ +trait HasManyMutableErasedType extends HasOnceMutableErasedType: /** Tracks whether [[`observeErasedTypeAssign`]] has seen at least one assignment to this symbol. */ private var erasedTypeObserved: Bool = false diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 074ed6238f..8575e237bc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -681,7 +681,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val inst = instantiateCapture // * The capture symbol holds an instance of the capture class, so it takes that type. captureSym match - case s: HasMutableErasedType => s.erasedType = inst.erasedType + case s: HasOnceMutableErasedType => s.erasedType = inst.erasedType case _ => Scoped( objSyms.toSet + captureSym, diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 96dd97aa29..38e5cbba5e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -564,10 +564,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): /** Observes an assignment of `rhs` to `sym`, populating or updating its erased type where applicable. * - * See [[`HasMutableErasedType.observeErasedTypeAssign`]]. + * See [[`HasManyMutableErasedType.observeErasedTypeAssign`]]. */ private def observeLocalErasedType(sym: LocalVarSymbol, rhs: Result): Unit = sym match - case sym: HasMutableErasedType => sym.observeErasedTypeAssign(rhs.erasedType) + case sym: HasManyMutableErasedType => sym.observeErasedTypeAssign(rhs.erasedType) case _ => private def defineSymbol(sym: Symbol, rhs: Result, rest: Block)(using LoweringCtx): Block = diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 0e5f7e6c96..75abc73700 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -10,9 +10,8 @@ import hkmc2.utils.* import Elaborator.State import Tree.Ident -import hkmc2.codegen.{ErasedType, PrimitiveType, HasErasedType, erasedType} +import hkmc2.codegen.{ErasedType, HasErasedType, HasManyMutableErasedType, HasOnceMutableErasedType, PrimitiveType, erasedType} import hkmc2.utils.SymbolSubst -import hkmc2.codegen.HasMutableErasedType sealed abstract class MaybeSymbol: @@ -251,7 +250,7 @@ class InstSymbol(val origin: Symbol)(using State) extends LocalSymbol: class VarSymbol(val id: Ident, override var erasedType: Opt[ErasedType])(using State) extends LocalVarSymbol(id.name) - with HasMutableErasedType + with HasManyMutableErasedType with NamedSymbol: val name: Str = id.name override def toLoc: Opt[Loc] = id.toLoc @@ -373,7 +372,7 @@ sealed abstract class MemberSymbol(using State) extends Symbol: class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident, override var erasedType: Opt[ErasedType])(using State) extends MemberSymbol with DefinitionSymbol[TermDefinition] - with HasMutableErasedType + with HasOnceMutableErasedType with NamedSymbol: def nme: Str = id.name def name: Str = nme From 8c380bd77a2b9065ee81a22cbe47431c2d8bf8b5 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 16:45:06 +0800 Subject: [PATCH 37/48] codegen: Make `N: Opt[ErasedType]` mean an unknown type --- .../src/main/scala/hkmc2/codegen/Block.scala | 33 ++++++++++++++----- .../main/scala/hkmc2/codegen/Lowering.scala | 10 +++--- .../src/test/mlscript/codegen/ErasedType.mls | 4 +-- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 32d677abf8..52dc3b48e2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -865,6 +865,18 @@ object ErasedType: case _ if csym === ctx.builtins.Object => ObjectRef case S(prim) => ErasedType.Primitive(prim) case _ => ErasedType.AnyRef(rsc, csym) + + /** Joins two possibly-unknown erased types and returning the result. + * + * - If both sides are known and equal, returns that known type. + * - If both sides are known and unequal, returns the top type [[`ErasedType.ObjectRef`]]. + * - If either side is unknown (`N`), returns `N` - Representing a possibly-unknown type. + */ + // TODO(Derppening): Widen conflicting knowns to their common ancestor once erased types track subtyping, + // rather than going directly to the top type. + def join(lhs: Opt[ErasedType], rhs: Opt[ErasedType]): Opt[ErasedType] = (lhs, rhs) match + case (N, _) | (_, N) => N + case (S(l), S(r)) => if l == r then S(l) else S(ObjectRef) /** A generics-erased type of the Block IR. */ enum ErasedType: @@ -910,25 +922,28 @@ trait HasOnceMutableErasedType extends HasErasedType: softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") if erasedType.isEmpty then erasedType = S(newType) -/** A [[`HasOnceMutableErasedType`]] whose erased type can additionally be assigned multiple times, joining the - * observed types into the top type on disagreement. */ +/** A [[`HasOnceMutableErasedType`]] whose erased type can additionally be assigned multiple times, joining each + * observed type into the running erased type via [[`ErasedType.join`]]. */ trait HasManyMutableErasedType extends HasOnceMutableErasedType: - /** Tracks whether [[`observeErasedTypeAssign`]] has seen at least one assignment to this symbol. */ + /** Tracks whether [[`observeErasedTypeAssign`]] has seen at least one assignment to this symbol. Needed to + * distinguish a never-observed symbol (whose `N` erased type is the join unit, to be seeded by the first + * observation) from an observed-but-poisoned one (whose `N` erased type is the absorbing top). */ private var erasedTypeObserved: Bool = false /** * Observes an assignment of a value with type `observed` to this symbol. * - * Some symbols (e.g. [[`LocalVarSymbol`]]) can be assigned to multiple times with different types of values. This - * method tracks multiple assignments by coercing the type of the symbol to the top type if a subsequent assignemnt - * does not share the same type as the previously populated type. + * Some symbols (e.g. [[`LocalVarSymbol`]]) can be assigned multiple times with values of differing types. The + * first observation of an otherwise-unset symbol seeds the erased type directly; every subsequent observation + * (and any observation joining an annotation-populated type) is folded in with [[`ErasedType.join`]]. Under that + * join an unknown (`N`) assignment is "poison": it widens the symbol to the unknown top type and is absorbing, + * while two differing known types widen to the definitive top type [[`ErasedType.ObjectRef`]]. */ - // TODO(Derppening): We should probably coerce to the common supertype rather than directly to the top type def observeErasedTypeAssign(observed: Opt[ErasedType]): Unit = if !erasedTypeObserved && erasedType.isEmpty then erasedType = observed - else if erasedType != observed then - erasedType = S(ErasedType.ObjectRef) + else + erasedType = ErasedType.join(erasedType, observed) erasedTypeObserved = true extension (lit: Literal) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 38e5cbba5e..6ce253cb52 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -1427,9 +1427,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): private def populateRestParamErasedType(p: Param): Unit = p.sym.populateErasedType(ErasedType.Primitive(PrimitiveType.Array)) - /** Infers the [[`ErasedType`]] of a function's return type, by inspecting the erased type of all return values. */ - // TODO(Derppening): This should return `N` only if any return value is `N` - Conflicting known return types should - // be joined to `AnyRef` (or common ancestor) + /** Infers the [[`ErasedType`]] of a function's return type by joining the erased types of all return values + * via [[`ErasedType.join`]]. + */ private def inferReturn(body: Block): Opt[ErasedType] = var rets: Ls[Opt[ErasedType]] = Nil new BlockTraverserShallow: @@ -1438,8 +1438,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case _ => super.applyBlock(b) .applyBlock(body) rets match - case head :: tail if tail.forall(_ == head) => head - case _ => N + case head :: tail => tail.foldLeft(head)(ErasedType.join) + case Nil => N /** Populates a function definition symbol's erased type with a [[`ErasedType.FuncRef`]] derived from its * (already-populated) parameter symbols and return type. diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 185b13c80a..30c48cd84a 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -246,8 +246,8 @@ fun counter() = //│ set scope0$cap.n$0¹ = tmp; //│ return runtime⁰.Unit⁰ //│ }; -//│ define counter⁰ as fun counter¹(): Object { -//│ let n: Object, scope0$cap: Capture$scope0²; +//│ define counter⁰ as fun counter¹(): ? { +//│ let n: ?, scope0$cap: Capture$scope0²; //│ set scope0$cap = new mut Capture$scope0²(n); //│ set scope0$cap.n$0¹ = 0; //│ do bump¹(scope0$cap); From 4085dffedb790c274ad55b0e12cd9bf911856ad0 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 18:53:26 +0800 Subject: [PATCH 38/48] codegen: Propagate erased types of capture fields --- .../src/main/scala/hkmc2/codegen/Lifter.scala | 7 +++- .../src/test/mlscript/codegen/ErasedType.mls | 37 ++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 8575e237bc..07349a15a6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -567,9 +567,12 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val nme = sym.nme + "$" + id val ident = new Tree.Ident(nme) - val varSym = VarSymbol(ident, erasedType = N) + val capturedType = sym match + case s: HasErasedType => s.erasedType + case _ => N + val varSym = VarSymbol(ident, erasedType = capturedType) val fldSym = BlockMemberSymbol(nme, Nil) - val tSym = TermSymbol(syntax.MutVal, S(clsSym), ident, erasedType = N) + val tSym = TermSymbol(syntax.MutVal, S(clsSym), ident, erasedType = capturedType) val p = Param(FldFlags.empty.copy(isVal = true), varSym, N, Modulefulness.none) varSym.decl = S(p) // * Currently this is only accessed to create the class' toString method diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index 30c48cd84a..e0dc959102 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -224,7 +224,10 @@ id(new Foo(1)) // Closure capture (`:lift`): the symbol holding the capture object is typed as -// the generated capture class. +// the generated capture class. The capture field inherits the captured local's +// erased type — but here `n` is reassigned from `n + 1`, whose operand type is +// not yet known when the (hoisted) closure body is lowered, so the local (and +// hence the field) widens to `?`. :siret :lift fun counter() = @@ -254,3 +257,35 @@ fun counter() = //│ return scope0$cap.n$0¹ //│ }; //│ end + + +// When the captured local keeps a definite erased type (here `n` is only ever +// assigned `Int` literals), the generated capture field inherits it: both the +// constructor parameter `n$0` and the `val n$0` field are typed `Int`. +:siret +:lift +fun resetCounter() = + let n = 0 + fun reset() = set n = 100 + reset() + n +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let reset⁰, resetCounter⁰, Capture$scope0³; +//│ define Capture$scope0³ as class Capture$scope0⁵ { +//│ constructor Capture$scope0⁴(n$0: Int) { +//│ define n$0² as val n$0³: Int = n$0; +//│ end +//│ } +//│ }; +//│ define reset⁰ as fun reset¹(scope0$cap: ?): Unit⁰ { +//│ set scope0$cap.n$0³ = 100; +//│ return runtime⁰.Unit⁰ +//│ }; +//│ define resetCounter⁰ as fun resetCounter¹(): Int { +//│ let n: Int, scope0$cap: Capture$scope0⁵; +//│ set scope0$cap = new mut Capture$scope0⁵(n); +//│ set scope0$cap.n$0³ = 0; +//│ do reset¹(scope0$cap); +//│ return scope0$cap.n$0³ +//│ }; +//│ end From f03bba6c3b9431dd2f78e4cc371dba0cf3a98ec3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 18:59:54 +0800 Subject: [PATCH 39/48] codegen: Propagate erased types of params-from-lifted locals --- .../src/main/scala/hkmc2/codegen/Lifter.scala | 5 ++++- .../src/test/mlscript/codegen/ErasedType.mls | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 07349a15a6..81a7fa0095 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -1002,7 +1002,10 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): class LiftedFunc(override val obj: ScopedObject.Func)(using ctx: LifterCtxNew) extends LiftedScope[FunDefn](obj) with GenericRewrittenScope[FunDefn]: private val passedSymsMap_ : Map[ValueSymbol, VarSymbol] = passedSymsOrdered.map: s => - s -> VarSymbol(Tree.Ident(s.nme), erasedType = N) + val erasedType = s match + case h: HasErasedType => h.erasedType + case _ => N + s -> VarSymbol(Tree.Ident(s.nme), erasedType) .toMap private val capSymsMap_ : Map[ScopedInfo, VarSymbol] = capturesOrdered.map: i => val nme = data.getNode(i).obj.nme diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls index e0dc959102..5d4aacb5e9 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -289,3 +289,22 @@ fun resetCounter() = //│ return scope0$cap.n$0³ //│ }; //│ end + + +// A read-only capture is lifted into a by-value parameter rather than a capture +// field; that parameter inherits the captured local's erased type (`k: Int`). +// (`rd`'s return stays `?`: the hoisted closure body is lowered before `k` is +// seeded, so the return-type inference cannot yet see `k`'s type.) +:siret +:lift +fun reader() = + let k = 0 + fun rd() = k + rd() +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let rd⁰, reader⁰; +//│ define rd⁰ as fun rd¹(k: Int): ? { +//│ return k +//│ }; +//│ define reader⁰ as fun reader¹(): ? { let k: Int; set k = 0; return rd¹(k) }; +//│ end From 7656e2bbe9154d7159e38d92f61365860c69143d Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 19:30:16 +0800 Subject: [PATCH 40/48] codegen: Expand `FuncRef` and remove top-level optionality --- hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala | 7 +++---- hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala | 2 +- hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala | 5 ++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 52dc3b48e2..dfaaf06441 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -890,7 +890,7 @@ enum ErasedType: case AnyRef(rsc: Bool, csym: ClassLikeSymbol | NoSymbol) /** A reference to a function of a possibly-known shape. */ - case FuncRef(sig: Opt[Ls[Opt[ErasedType]] -> Opt[ErasedType]]) + case FuncRef(params: Ls[Opt[ErasedType]], ret: Opt[ErasedType]) /** An primitive type. */ case Primitive(prim: PrimitiveType) @@ -899,7 +899,7 @@ enum ErasedType: def sym(using Ctx, State): ClassLikeSymbol = this match case AnyRef(_, csym: ClassLikeSymbol) => csym case AnyRef(_, _: NoSymbol) => ctx.builtins.Object - case FuncRef(_) => ctx.builtins.Function + case FuncRef(_, _) => ctx.builtins.Function case Primitive(prim) => prim.sym /** Trait representing a Block IR element that has an [[`ErasedType`]]. */ @@ -1052,9 +1052,8 @@ sealed abstract class Result extends AutoLocated, HasErasedType: case Call(Value.SimpleRef(bs: BuiltinSymbol), argss) => bs.resultErasedType(argss.head.map(_.value.erasedType)) case Call(fun, _) => fun.targetSymbol match - // * A call's result is the callee's return type, recovered from its `FuncRef` (Phase F.2). case S(ts: TermSymbol) => ts.erasedType match - case S(ErasedType.FuncRef(sig)) => sig.flatMap(_._2) + case S(ErasedType.FuncRef(_, ret)) => ret case _ => N case _ => N // * A resolved selection has the type of the member it refers to (e.g. `this.field`); an diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 6ce253cb52..2715984028 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -1451,7 +1451,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if tsym.erasedType.isEmpty then val params = paramLists.flatMap(_.params).map(_.sym.erasedType) val ret = sign.flatMap(eraseSign) orElse inferReturn(body) - tsym.erasedType = S(ErasedType.FuncRef(S(params -> ret))) + tsym.erasedType = S(ErasedType.FuncRef(params, ret)) def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) (using LoweringCtx): (List[ParamList], Block) = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index 129c760d3a..5c10b26996 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -33,8 +33,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): def print(et: ErasedType)(using Scope): Document = et match case ErasedType.AnyRef(rsc, csym: ClassLikeSymbol) => doc"${if rsc then "rsc " else ""}${print(csym)}" case ErasedType.AnyRef(rsc, _: NoSymbol) => doc"${if rsc then "rsc " else ""}Object" - case ErasedType.FuncRef(N) => doc"Function" - case ErasedType.FuncRef(S(params -> ret)) => + case ErasedType.FuncRef(params, ret) => doc"(${params.map(_.fold(doc"?")(print)).mkDocument(sep = doc", ")}) => ${ret.fold(doc"?")(print)}" case ErasedType.Primitive(prim) => doc"${prim.toString}" @@ -47,7 +46,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): def returnTypeAnnot(dSym: TermSymbol)(using Scope): Document = if !summon[ShowCfg].showErasedTypes then doc"" else dSym.erasedType match - case S(ErasedType.FuncRef(sig)) => doc": ${sig.flatMap(_._2).fold(doc"?")(print)}" + case S(ErasedType.FuncRef(_, ret)) => doc": ${ret.fold(doc"?")(print)}" case _ => doc"" def print(blk: Block)(using Scope): Document = blk match From 6aa62d6cd603620fc899f43abf18ed5c6d2d3b4e Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 19:44:45 +0800 Subject: [PATCH 41/48] semantics: Make `NoSymbol` a `case object` --- .../src/main/scala/hkmc2/codegen/Block.scala | 20 +++++++++---------- .../hkmc2/codegen/BlockTransformer.scala | 2 +- .../scala/hkmc2/codegen/BlockTraverser.scala | 2 +- .../scala/hkmc2/codegen/HandlerLowering.scala | 2 +- .../src/main/scala/hkmc2/codegen/Lifter.scala | 2 +- .../main/scala/hkmc2/codegen/Printer.scala | 4 ++-- .../codegen/ReflectionInstrumenter.scala | 4 ++-- .../hkmc2/codegen/SpecializedSwitch.scala | 2 +- .../hkmc2/codegen/StackSafeTransform.scala | 4 ++-- .../scala/hkmc2/codegen/SymbolRefresher.scala | 2 +- .../scala/hkmc2/codegen/UsedVarAnalyzer.scala | 4 ++-- .../codegen/flowAnalysis/FlowAnalysis.scala | 2 +- .../scala/hkmc2/codegen/js/JSBuilder.scala | 2 +- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 2 +- .../scala/hkmc2/semantics/Elaborator.scala | 2 +- .../main/scala/hkmc2/semantics/Symbol.scala | 2 +- .../src/main/scala/hkmc2/utils/utils.scala | 3 ++- 17 files changed, 31 insertions(+), 30 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index dfaaf06441..dde5b3b4d7 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -46,7 +46,7 @@ case class Program( type SimpleSymbol = LocalVarSymbol | BuiltinSymbol /** Symbol that can be used as the left-hand side of an `Assign`. */ -type Assignable = LocalVarSymbol | NoSymbol +type Assignable = LocalVarSymbol | NoSymbol.type /** Symbols that `Scoped` introduces as block-local bindings. * This deliberately excludes things like `TermSymbol`s, which never need to be scoped. @@ -148,7 +148,7 @@ sealed abstract class Block extends Product: lazy val definedVars: Set[BoundSymbol] = this match case _: Return | _: Throw | _: Unreachable => Set.empty case Begin(sub, rst) => sub.definedVars ++ rst.definedVars - case Assign(_: NoSymbol, r, rst) => rst.definedVars + case Assign(NoSymbol, r, rst) => rst.definedVars case Assign(l: LocalVarSymbol, r, rst) => rst.definedVars + l case AssignField(l, n, r, rst) => rst.definedVars case AssignDynField(l, n, ai, r, rst) => rst.definedVars @@ -201,7 +201,7 @@ sealed abstract class Block extends Product: case Continue(label) => Set.single(label) case Begin(sub, rest) => sub.freeVars ++ rest.freeVars case TryBlock(sub, finallyDo, rest) => sub.freeVars ++ finallyDo.freeVars ++ rest.freeVars - case Assign(_: NoSymbol, rhs, rest) => rhs.freeVars ++ rest.freeVars + case Assign(NoSymbol, rhs, rest) => rhs.freeVars ++ rest.freeVars case Assign(lhs: LocalVarSymbol, rhs, rest) => Set.single(lhs) ++ rhs.freeVars ++ rest.freeVars case AssignField(lhs, nme, rhs, rest) => lhs.freeVars ++ rhs.freeVars ++ rest.freeVars case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVars ++ fld.freeVars ++ rhs.freeVars ++ rest.freeVars @@ -435,14 +435,14 @@ object Assign: case Scoped(syms, body) => Scoped(syms, Assign(lhs, rhs, body)) case _ => lhs match - case _: NoSymbol => + case NoSymbol => if rhs.isPure then rest else new Assign(lhs, rhs, rest) case _ => new Assign(lhs, rhs, rest) def discard(res: Result, rest: Block)(using State): Block = res match case _: Value | _: Lambda => rest case p: Path if p.isPure => rest - case r => Assign(State.noSymbol, r, rest) + case r => Assign(NoSymbol, r, rest) object AssignField: def apply(lhs: Path, nme: Tree.Ident, rhs: Result, rest: Block)(symbol: Opt[MemberSymbol]): Block = rest match case Scoped(syms, body) => Scoped(syms, AssignField(lhs, nme, rhs, body)(symbol)) @@ -857,7 +857,7 @@ enum PrimitiveType: case Array => ctx.builtins.Array object ErasedType: - def ObjectRef: ErasedType.AnyRef = AnyRef(rsc = false, NoSymbol()) + def ObjectRef: ErasedType.AnyRef = AnyRef(rsc = false, NoSymbol) /** Maps a [[`ClassLikeSymbol`]] into the canonical [[`ErasedType`]]. */ def fromClsLikeSymbol(csym: ClassLikeSymbol, rsc: Bool)(using Ctx, State): ErasedType = @@ -887,7 +887,7 @@ enum ErasedType: * * - `rsc` is true if this reference is a resource class. */ - case AnyRef(rsc: Bool, csym: ClassLikeSymbol | NoSymbol) + case AnyRef(rsc: Bool, csym: ClassLikeSymbol | NoSymbol.type) /** A reference to a function of a possibly-known shape. */ case FuncRef(params: Ls[Opt[ErasedType]], ret: Opt[ErasedType]) @@ -898,7 +898,7 @@ enum ErasedType: /** The symbol for this erased type. */ def sym(using Ctx, State): ClassLikeSymbol = this match case AnyRef(_, csym: ClassLikeSymbol) => csym - case AnyRef(_, _: NoSymbol) => ctx.builtins.Object + case AnyRef(_, NoSymbol) => ctx.builtins.Object case FuncRef(_, _) => ctx.builtins.Function case Primitive(prim) => prim.sym @@ -1136,7 +1136,7 @@ object Value: @deprecated("Use Value.SimpleRef, Value.MemberRef, or Value.This instead.") object Ref: - def apply(l: ValueSymbol | NoSymbol, disamb: Opt[DefinitionSymbol[?]]): Value.RefLike = + def apply(l: ValueSymbol | NoSymbol.type, disamb: Opt[DefinitionSymbol[?]]): Value.RefLike = l match case l: SimpleSymbol => l.asSimpleRef case l: TermSymbol => l.asPath @@ -1144,7 +1144,7 @@ object Value: disamb.getOrElse: lastWords(s"Cannot disambiguate overloaded member symbol ${bms.nme}: no disambiguation provided") case sym: InnerSymbol => sym.asThis - case _: NoSymbol => lastWords("NoSymbol should not be used as a Path/Value") + case NoSymbol => lastWords("NoSymbol should not be used as a Path/Value") // * Some helper constructors that allow omitting the disambiguation symbol. // * If the ref itself is a DefinitionSymbol, then disambiguating it results in itself. diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala index a1ef76af45..f376eed3cc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala @@ -213,7 +213,7 @@ class BlockTransformer(subst: SymbolSubst): case sym: BlockMemberSymbol => sym.subst def applyAssignLhs(sym: Assignable): Assignable = sym match - case sym: NoSymbol => sym + case NoSymbol => NoSymbol case sym: TempSymbol => sym.subst case sym: VarSymbol => sym.subst diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala index 0dfbf843bc..ec06cc9470 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala @@ -25,7 +25,7 @@ class BlockTraverser: def applyMaybeSymbol(sym: MaybeSymbol): Unit = sym match - case _: NoSymbol => () + case NoSymbol => () case sym: Symbol => applySymbol(sym) def applySymbol(sym: Symbol): Unit = () diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 434f108ad4..141c88bb8d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -390,7 +390,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, applyResult(rhs) lhs match case lhs: LocalVarSymbol => assignToSym(lhs) - case _: NoSymbol => + case NoSymbol => applyBlock(rest) case Define(defn: ValDefn, rest) => applyPath(defn.rhs) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 81a7fa0095..c1c8f007d9 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -510,7 +510,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): if (sub2 is sub) && (fin2 is fin) && (rst2 is rst) then rewritten else TryBlock(sub2, fin2, rst2) // Assignment to variables - case Assign(_: NoSymbol, _, _) => super.applyBlock(rewritten) + case Assign(NoSymbol, _, _) => super.applyBlock(rewritten) case Assign(lhs: LocalVarSymbol, rhs, rest) => ctx.symbolsMap.get(lhs) match case Some(path) => applyResult(rhs): rhs2 => path.assign(rhs2, applySubBlock(rest)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index 5c10b26996..a340ff1e15 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -32,7 +32,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): def print(et: ErasedType)(using Scope): Document = et match case ErasedType.AnyRef(rsc, csym: ClassLikeSymbol) => doc"${if rsc then "rsc " else ""}${print(csym)}" - case ErasedType.AnyRef(rsc, _: NoSymbol) => doc"${if rsc then "rsc " else ""}Object" + case ErasedType.AnyRef(rsc, NoSymbol) => doc"${if rsc then "rsc " else ""}Object" case ErasedType.FuncRef(params, ret) => doc"(${params.map(_.fold(doc"?")(print)).mkDocument(sep = doc", ")}) => ${ret.fold(doc"?")(print)}" case ErasedType.Primitive(prim) => doc"${prim.toString}" @@ -75,7 +75,7 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): doc"begin #{ # ${print(sub)}; #} # ${print(rest)}" case TryBlock(sub, finallyDo, rest) => doc"try #{ # ${print(sub)} #} # finally #{ # ${print(finallyDo)}; # #} ${print(rest)}" - case Assign(_: NoSymbol, rhs, rest) => + case Assign(NoSymbol, rhs, rest) => doc"do ${print(rhs)}; # ${print(rest)}" case Assign(lhs: (LocalVarSymbol | TermSymbol), rhs, rest) => doc"set ${print(lhs)} = ${print(rhs)}; # ${print(rest)}" diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala index 41bb324c8d..0037f4c8a7 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala @@ -121,7 +121,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n blockCtor("ConcreteClassSymbol", Ls(toValue(name), path, paramsOpt, auxParams), symName)(k) case _: ModuleOrObjectSymbol => blockCtor("ModuleSymbol", Ls(toValue(name), path), symName)(k) - case _: NoSymbol => + case NoSymbol => blockCtor("NoSymbol", Nil, symName)(k) case sym: LocalVarSymbol => val name = scope.allocateOrGetName(sym) @@ -288,7 +288,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n blockCtor("ValueSimpleRef", Ls(xSym)): xStaged => (Assign(x, xStaged, _)): given Context = x match - case _: NoSymbol => ctx.clone() + case NoSymbol => ctx.clone() case x: ValueSymbol => ctx.clone() += x.asPath -> xStaged transformBlock(b): (z, ctx) => blockCtor("Assign", Ls(xSym, y, z), "assign")(k(_, ctx)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala index 31eb551088..6b8d3905ed 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SpecializedSwitch.scala @@ -150,7 +150,7 @@ private object PostCondAnalysisImpl extends CachedAnalysis[Block, PostCondRes]: case Scoped(syms, body) => analyze(body) case Begin(sub, rest) => analyze(sub) >=> analyze(rest) case TryBlock(sub, finallyDo, rest) => analyze(sub) >=> analyze(finallyDo) >=> analyze(rest) - case Assign(_: NoSymbol, rhs, rest) => res(N, rhs, rest) + case Assign(NoSymbol, rhs, rest) => res(N, rhs, rest) case Assign(lhs: ValueSymbol, rhs, rest) => res(S(lhs), rhs, rest) case AssignField(path, _, rhs, rest) => res(N, rhs, rest) case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => res(N, rhs, rest) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala index ec37a78184..90af68aba2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -33,7 +33,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S .rest: sym match case sym: LocalVarSymbol => f(sym.asSimpleRef) - case _: NoSymbol => f(Value.Lit(Tree.UnitLit(false))) + case NoSymbol => f(Value.Lit(Tree.UnitLit(false))) def wrapStackSafe(body: Block, resSym: Assignable, rest: Block) = val bodSym = BlockMemberSymbol("‹stack safe body›", Nil, false) @@ -45,7 +45,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S def extractResTopLevel(res: Result, isTailCall: Bool, f: Result => Block, sym: Assignable, curDepth: => LocalVarSymbol) = sym match case sym: LocalVarSymbol => wrapStackSafe(Ret(res), sym, f(sym.asSimpleRef)) - case _: NoSymbol => wrapStackSafe(Ret(res), sym, f(Value.Lit(Tree.UnitLit(false)))) + case NoSymbol => wrapStackSafe(Ret(res), sym, f(Value.Lit(Tree.UnitLit(false)))) // Rewrites anything that can contain a Call to increase the stack depth def transform(b: Block, curDepth: => LocalVarSymbol, isTopLevel: Bool = false): Block = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala index 1a0e71cc9c..593524f84e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala @@ -156,7 +156,7 @@ private class SymbolRefresherInternal(m: MutMap[Symbol, Symbol])(using State) ex override def applyImportSymbol(s: ImportSymbol): ImportSymbol = m.getOrElse(s, s).asInstanceOf[ImportSymbol] override def applyAssignLhs(s: Assignable): Assignable = s match - case s: NoSymbol => s + case NoSymbol => NoSymbol case s: LocalVarSymbol => m.getOrElse(s, s).asInstanceOf[LocalVarSymbol] class SymbolRefresher(m: Map[Symbol, Symbol])(using State) extends SymbolRefresherInternal(MutMap.from(m)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala index 74b695b0f8..1e134dbf5a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala @@ -58,7 +58,7 @@ class UsedVarAnalyzer(b: Block, scopeData: ScopeData)(using State): accessed.refdDefns.add(scopeData.getUID(s)) case Assign(lhs, rhs, rest) => lhs match - case _: NoSymbol => () + case NoSymbol => () case lhs: ScopedOrInnerSymbol => accessed.accessed.add(lhs) accessed.mutated.add(lhs) @@ -320,7 +320,7 @@ class UsedVarAnalyzer(b: Block, scopeData: ScopeData)(using State): case Assign(lhs, rhs, rest) => applyResult(rhs) lhs match - case _: NoSymbol => () + case NoSymbol => () case lhs: ScopedOrInnerSymbol => if hasReader.contains(lhs) || hasMutator.contains(lhs) then reqCapture += lhs if !linearValueVars.contains(lhs) then mutated += lhs diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala index adedd750c2..d9808c5801 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/flowAnalysis/FlowAnalysis.scala @@ -893,7 +893,7 @@ class FlowConstraintsCollector( case Assign(lhs, rhs, rest) => val rhsStrat = processResult(rhs) lhs.match - case _: NoSymbol => () + case NoSymbol => () case lhs: (LocalVarSymbol | TermSymbol) => cc.constrain(rhsStrat, generatedProdVars(lhs).asConsStrat) processBlock(rest) case TryBlock(sub, finallyDo, rest) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index e2b5a36aeb..7543780cb3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -338,7 +338,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: def returningTerm(t: Block, endSemi: Bool)(using Raise, Scope): Document = def mkSemi = if endSemi then ";" else "" t match - case Assign(l: NoSymbol, r, rst) => + case Assign(NoSymbol, r, rst) => doc" # ${result(r)};${returningTerm(rst, endSemi)}" case Assign(l: (LocalVarSymbol | TermSymbol), r, rst) => doc" # ${ diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 1cf70befa3..0d097c375c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -1664,7 +1664,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: def returningTerm(t: Block)(using Ctx, FunctionCtx, Raise, SessionExportCtx): Expr = t match - case Assign(l: NoSymbol, r, rst) => + case Assign(NoSymbol, r, rst) => val rExpr = result(r) val evalExpr = rExpr.resultType match case S(_) => drop(rExpr) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 78bcaaf737..521e2b2e4f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -336,7 +336,7 @@ object Elaborator: val strSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Mod), Ident("Str")) // In JavaScript, `import` can be used for getting current file path, as `import.meta` val importSymbol = new VarSymbol(Ident("import"), erasedType = N) - val noSymbol = NoSymbol() + val noSymbol = NoSymbol val runtimeSymbol = TempSymbol(N, erasedType = N, "runtime") val definitionMetadataSymbol = TempSymbol(N, erasedType = N, "definitionMetadata") val prettyPrintSymbol = TempSymbol(N, erasedType = N, "prettyPrint") diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 75abc73700..7eec9b9438 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -155,7 +155,7 @@ end Symbol // * Used, eg, as the Assign receiver of intermediate computations whose result is not used -final class NoSymbol extends MaybeSymbol: +case object NoSymbol extends MaybeSymbol: def nme: Str = "‹no symbol›" override def toString: Str = nme diff --git a/hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala b/hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala index 4b6ca281bf..8e418ebc61 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala @@ -92,7 +92,8 @@ class DebugPrinter: case codegen.Scoped(syms, body) => val symsStr = "{" + syms.toArray.sortBy(_.uid).map(_.showAsPlain).mkString(", ") + "}" s"Scoped(syms = $symsStr): \n" + s"body = ${printProduct(false, body)}".indent(" ") - + + case semantics.NoSymbol => printPlain(semantics.NoSymbol) case t: Product => printProduct(inTailPos, t) case v => printPlain(v) From 8fdc420c1cb16d79d1b321d3a724ec13ae31d2cc Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 20:02:26 +0800 Subject: [PATCH 42/48] semantics: Make `NoSymbol` a plain `object` --- hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala | 2 +- hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 7eec9b9438..6f5bdf98b0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -155,7 +155,7 @@ end Symbol // * Used, eg, as the Assign receiver of intermediate computations whose result is not used -case object NoSymbol extends MaybeSymbol: +object NoSymbol extends MaybeSymbol: def nme: Str = "‹no symbol›" override def toString: Str = nme diff --git a/hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala b/hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala index 8e418ebc61..093fbf1e47 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala @@ -93,7 +93,6 @@ class DebugPrinter: val symsStr = "{" + syms.toArray.sortBy(_.uid).map(_.showAsPlain).mkString(", ") + "}" s"Scoped(syms = $symsStr): \n" + s"body = ${printProduct(false, body)}".indent(" ") - case semantics.NoSymbol => printPlain(semantics.NoSymbol) case t: Product => printProduct(inTailPos, t) case v => printPlain(v) From fc5367d3ccfcf1fdfaa7dbd78d1b55b679e0ce6c Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 20:05:56 +0800 Subject: [PATCH 43/48] semantics: Refactor and deprecate `State.noSymbol` --- hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala | 2 +- hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala | 1 + .../src/main/scala/hkmc2/semantics/ucs/Normalization.scala | 2 +- hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index dde5b3b4d7..584831c103 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -556,7 +556,7 @@ object HandleBlock: N, Nil, S(par), handlerMtds, Nil, Nil, // Apparently, the lifter is not happy with any assignment in the preCtor... - Assign(State.noSymbol, Call(State.builtinOpsMap("super").asSimpleRef, args.map(_.asArg) ne_:: Nil)(true, true, false), End()), + Assign(NoSymbol, Call(State.builtinOpsMap("super").asSimpleRef, args.map(_.asArg) ne_:: Nil)(true, true, false), End()), End(), N, N, diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 521e2b2e4f..06ca97c5a8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -336,6 +336,7 @@ object Elaborator: val strSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Mod), Ident("Str")) // In JavaScript, `import` can be used for getting current file path, as `import.meta` val importSymbol = new VarSymbol(Ident("import"), erasedType = N) + @deprecated("Use the `NoSymbol` singleton instead.") val noSymbol = NoSymbol val runtimeSymbol = TempSymbol(N, erasedType = N, "runtime") val definitionMetadataSymbol = TempSymbol(N, erasedType = N, "definitionMetadata") diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 23e36eee78..db80fec130 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -466,7 +466,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C form match case IfLikeForm.ReturningIf => if (k is Ret) || (k is Thrw) then k(r) else Assign(l, r, End()) case IfLikeForm.ImperativeIf => Assign.discard(r, End()) - case IfLikeForm.While => Assign(State.noSymbol, r, loopCont) + case IfLikeForm.While => Assign(NoSymbol, r, loopCont) // NOTE: `shouldRewriteWhile` is not the same as `config.rewriteWhileLoops` // as shouldRewriteWhile is always true when effect handler lowering is on lazy val loopCont = if config.shouldRewriteWhile diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala index 5d0e63e33c..43aaa81b45 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala @@ -296,7 +296,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val le = import codegen.* Assign( - Elaborator.State.noSymbol, + NoSymbol, Call( Elaborator.State.runtimeSymbol.asSimpleRef.selSN("printRaw"), (Arg(N, sym.asPath) :: Nil) ne_:: Nil)(true, false, false), From 098ef027036f9a1322102c27cb12d3a4439fbfb7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jun 2026 20:12:50 +0800 Subject: [PATCH 44/48] Fix more `State.noSymbol` deprecation warnings --- .../shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala | 2 +- hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala | 4 ++-- .../src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 141c88bb8d..8217d63b94 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -750,7 +750,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val ctx = HandlerCtx.TopLevel val transformed = translateBlock(preTransformed, ctx, Set.empty) val blk = blockBuilder - .assign(State.noSymbol, Call(paths.resetEffects, Nil ne_:: Nil)(true, false, false)) + .assign(NoSymbol, Call(paths.resetEffects, Nil ne_:: Nil)(true, false, false)) .rest(transformed) (blk, stackSafetyMap) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 2715984028..ee0e8bf82f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -154,7 +154,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): isTailCall = false, args, N, // TODO: location? - )(c => Assign(State.noSymbol, c, End())) + )(c => Assign(NoSymbol, c, End())) // * Used to work around Scala's @tailrec annotation for those few calls that are not in tail position. final def term_nonTail(t: st, inStmtPos: Bool = false)(k: Result => Block)(using LoweringCtx): Block = @@ -1525,7 +1525,7 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State) // * the access should throw an error like `TypeError: Cannot read property 'f' of undefined`. blockBuilder .assign(selRes, Select(p, nme)(disamb)) - .assign(State.noSymbol, Select(p, Tree.Ident(nme.name+"$__checkNotMethod"))(N)) + .assign(NoSymbol, Select(p, Tree.Ident(nme.name+"$__checkNotMethod"))(N)) .ifthen(selRes.asSimpleRef, Case.Lit(syntax.Tree.UnitLit(false)), Throw(Instantiate(mut = false, Select(State.globalThisSymbol.asThis, Tree.Ident("Error"))(N), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 0d097c375c..db9090f72d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -954,7 +954,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: def splitSuperTail(block: Block): Opt[Block -> Ls[Arg]] = block match case End(_) => N case Assign(lhs, Call(Value.SimpleRef(bs: BuiltinSymbol), argss), _: End) - if (lhs is State.noSymbol) && (bs is State.superSymbol) + if (lhs is NoSymbol) && (bs is State.superSymbol) => S(End("") -> argss.flatten) case b: NonBlockTail => From 3665ba22ef0db17c9dc2696efcf60c81dc855b26 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 2 Jul 2026 13:15:34 +0800 Subject: [PATCH 45/48] WIP: Implement lowering of `Int` -> `i31ref` --- .../scala/hkmc2/codegen/wasm/text/Ctx.scala | 31 +++++++---- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 35 +++++++++--- .../shared/src/test/mlscript/wasm/Basics.mls | 53 ++++++++++++++++++- 3 files changed, 103 insertions(+), 16 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala index dd42f7c1d0..4f44107a4a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala @@ -191,6 +191,10 @@ final class SessionExportCtx( * The expression of the function body. * @param exportName * Optional export name. + * @param typedParams + * Whether parameter slots should be declared with their `erasedType`-derived Wasm type (via + * `ValueSymbol.paramRefType`) rather than uniformly `anyref`. Only safe for top-level free functions, which are not + * subject to the shared vtable calling convention; `false` for everything else (methods, ctors, `init`). */ class FuncInfo( val sym: BlockMemberSymbol | TempSymbol, @@ -201,6 +205,7 @@ class FuncInfo( val body: Expr, val exportName: Opt[Str], val wrapId: Opt[Str] -> Opt[Str] = N -> N, + val typedParams: Bool = false, )(using Ctx, Raise, State) extends ToWat: /** Symbolic identifier for the function. */ @@ -208,7 +213,7 @@ class FuncInfo( /** Returns the type of this function as a [[SignatureType]]. */ def getSignatureType: SignatureType = SignatureType( - params = params.map((_, paramIdx) => WasmParam(paramIdx, RefType.anyref)), + params = params.map((sym, paramIdx) => WasmParam(paramIdx, if typedParams then sym.paramRefType else RefType.anyref)), results = resultTypes, ) @@ -360,8 +365,11 @@ object FunctionCtx: * The parameters of this function. * @param thisSym * The implicit `this` parameter symbol if this function is generated from a non-static method, or `N` otherwise. + * @param typedParams + * Whether parameter slots should be declared/loaded with their `erasedType`-derived Wasm type rather than uniformly + * `anyref`. */ -class FunctionCtx(_params: Ls[ParamList], thisSym: Opt[InnerSymbol])(using Raise, State): +class FunctionCtx(_params: Ls[ParamList], thisSym: Opt[InnerSymbol], typedParams: Bool = false)(using Raise, State): /** [[Scope]] for generating WAT identifiers of locals. */ private[text] val localScp = Scope.empty(Scope.Cfg.default) @@ -406,14 +414,18 @@ class FunctionCtx(_params: Ls[ParamList], thisSym: Opt[InnerSymbol])(using Raise /** The declared Wasm reference type of the param/local slot for `sym`. * - * Parameters are uniformly `anyref`: their declared type is fixed by the shared call/vtable - * calling convention, independent of `sym.erasedType` (e.g. a virtually-dispatched method's - * `this` must stay `anyref` to match the shared vtable signature even when its erased type names - * a concrete class). Local slots derive their type from the symbol's erased type via - * [[localRefType]]. + * Parameters are `anyref` by default: their declared type is fixed by the shared call/vtable + * calling convention. + * + * When `typedParams` is set (e.g. when compiling free functions), parameter slots instead derive their type from + * [[ValueSymbol.paramRefType]]. + * + * Local slots always derive their type from the symbol's erased type via [[localRefType]]. */ def slotRefType(sym: ValueSymbol)(using Ctx): RefType = - if params.exists(_._1 == sym) then RefType.anyref else sym.localRefType + if params.exists(_._1 == sym) then + if typedParams then sym.paramRefType else RefType.anyref + else sym.localRefType /** Pushes a label target for the dynamic extent of `body` and pops it afterwards. * @@ -455,8 +467,9 @@ end FunctionCtx def genFuncBody[T]( params: Ls[ParamList], thisSym: Opt[InnerSymbol], + typedParams: Bool = false, )(mkBody: FunctionCtx ?=> T)(using Raise, State): T -> FunctionCtx = - val funcCtx = FunctionCtx(params, thisSym) + val funcCtx = FunctionCtx(params, thisSym, typedParams) val result = mkBody(using funcCtx) result -> funcCtx diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 9b201f4531..13a52c10b3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -28,7 +28,7 @@ extension (instr: FoldedInstr) extension (et: ErasedType) /** Returns the corresponding Wasm type for this [[`ErasedType`]]. */ - private def wasmType(using Ctx): Opt[RefType] = + private[text] def wasmType(using Ctx): Opt[RefType] = import Ctx.ctx et match case ErasedType.Primitive(PrimitiveType.Int | PrimitiveType.Int31 | PrimitiveType.Bool) => @@ -53,6 +53,13 @@ extension (sym: ValueSymbol) structSym.flatMap(ctx.getType).map(RefType(_, nullable = false)).getOrElse(RefType.anyref) case _ => RefType.anyref + /** The Wasm reference type a parameter slot for `sym` should be declared with, if typed parameters are enabed. */ + private[text] def paramRefType(using Ctx, State): RefType = + sym match + case s: HasErasedType => + s.erasedType.collect { case p: ErasedType.Primitive => p }.flatMap(_.wasmType).getOrElse(RefType.anyref) + case _ => RefType.anyref + extension (exprs: Seq[Expr]) /** Merges a sequence of expressions to a single `block` expression if needed. */ def mergeAsBlock: Opt[Expr] = @@ -165,6 +172,20 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: require(expr.resultTypes.size == 1, "expected single-result expression for cast") if expr.resultType.contains(target) then expr else ref.cast(expr, target) + /** Casts each argument in `wasmArgs` down to the corresponding declared parameter type read from `funcTypeInfo`, + * narrowing `anyref` -> a concrete typed parameter. + * + * Note that this function does not cast an argument that is already narrower than the declared parameter type, + * since it is illegal to upcast using `ref.cast`. + */ + private def castArgsToParams(wasmArgs: Seq[Expr], funcTypeInfo: TypeInfo): Seq[Expr] = + val declParams = funcTypeInfo.compType.asInstanceOf[FunctionType].sigType.params + wasmArgs.zip(declParams).map: (arg, p) => + p.valtype match + case rt: RefType if rt.heapType =/= HeapType.Any && arg.resultType.contains(RefType.anyref) => + castConserve(arg, rt) + case _ => arg + /** Returns the default Wasm value for one struct field when eagerly constructing an object instance. */ private def defaultStructFieldValue(field: Field)(using Ctx, Raise): Expr = field.ty match case refTy: RefType if refTy.nullable => ref.`null`(refTy.heapType) @@ -1318,7 +1339,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: extraInfo = S(fun.toString), ) val baseTypeInfo = ctx.getTypeInfo_!(ctx.getFuncTypeUse_!(baseFuncIdx).typeIdx) - val wasmArgs = args.map(argument) + val wasmArgs = castArgsToParams(args.map(argument), baseTypeInfo) call( funcidx = baseFuncIdx, @@ -1334,7 +1355,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: extraInfo = S(fun.toString), ) val baseTypeInfo = ctx.getTypeInfo_!(ctx.getFuncTypeUse_!(baseFuncIdx).typeIdx) - val wasmArgs = args.map(argument) + val wasmArgs = castArgsToParams(args.map(argument), baseTypeInfo) call( funcidx = baseFuncIdx, @@ -1800,13 +1821,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val result = pss.foldRight(bod): case (ps, block) => Return(Lambda(ps, block)(Nil)) - val (bodyWat, fnCtx) = setupFunction(N, ps, result) + val (bodyWat, fnCtx) = setupFunction(N, ps, result, typedParams = true) if sym.nameIsMeaningful then val funcTy = ctx.addType( TypeInfo( sym = TempSymbol(N, erasedType = N, sym.nme), compType = FunctionType( - params = fnCtx.params.map(p => WasmParam(p._2, RefType.anyref)), + params = fnCtx.params.map(p => WasmParam(p._2, p._1.paramRefType)), results = Seq.fill(bodyWat.resultTypes.length)(Result(RefType.anyref)), ), objectTag = N, @@ -1821,6 +1842,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: locals = fnCtx.locals, body = bodyWat, exportName = sym.optionIf(_.nameIsMeaningful).map(_.nme), + typedParams = true, ) ctx.addFunc(funcInfo) if summon[SessionExportCtx].shouldExport(defn.sym) then @@ -2422,8 +2444,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: thisParam: Opt[InnerSymbol], params: ParamList, body: Block, + typedParams: Bool = false, )(using Ctx, Raise, SessionExportCtx): (Expr, FunctionCtx) = - genFuncBody(params :: Nil, thisSym = thisParam): + genFuncBody(params :: Nil, thisSym = thisParam, typedParams = typedParams): block(body).mergeAsBlock.getOrElse(nop) end WatBuilder diff --git a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls index 85573c27d8..ad2f679c56 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls @@ -85,6 +85,57 @@ foo(41) //│ = 42 +// Tests whether cast elision and insertion works correctly: +// - `addOne(41)` in the top-level elides a cast since `41: i31` and `(ref i31) <: (ref null i31)` +// - `addOne(v)` in `useIt` inserts a cast since `v: anyref` and `(ref null any) :> (ref null i31)` +:noInline +:wat +fun addOne(x: Int) = x + 1 +fun useIt(v) = addOne(v) +addOne(41) + useIt(41) +//│ Wat: +//│ (module +//│ (type $TypeInfoBase (sub (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) +//│ (type $Object (sub (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) +//│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) +//│ (type $addOne (func (param $x (ref null i31)) (result (ref null any)))) +//│ (type $useIt (func (param $v (ref null any)) (result (ref null any)))) +//│ (type $entry (func (result (ref null any)))) +//│ (import "system" "plus_impl" (func $plus_impl (type $plus_impl))) +//│ (func $addOne (export "addOne") (type $addOne) (param $x (ref null i31)) (result (ref null any)) +//│ (return +//│ (call $plus_impl +//│ (local.get $x) +//│ (ref.i31 +//│ (i32.const 1))))) +//│ (func $useIt (export "useIt") (type $useIt) (param $v (ref null any)) (result (ref null any)) +//│ (return +//│ (call $addOne +//│ (ref.cast (ref null i31) +//│ (local.get $v))))) +//│ (func $entry (export "entry") (type $entry) (result (ref null any)) +//│ (local $tmp (ref null any)) +//│ (local $tmp1 (ref null any)) +//│ (block (result (ref null any)) +//│ (local.set $tmp +//│ (call $addOne +//│ (ref.i31 +//│ (i32.const 41)))) +//│ (local.set $tmp1 +//│ (call $useIt +//│ (ref.i31 +//│ (i32.const 41)))) +//│ (return +//│ (call $plus_impl +//│ (local.get $tmp) +//│ (local.get $tmp1))))) +//│ (elem $addOne declare func $addOne) +//│ (elem $useIt declare func $useIt) +//│ (elem $entry declare func $entry)) +//│ Wasm result: +//│ = 84 + + class Foo(val x) new Foo(0) //│ Wasm result: @@ -283,5 +334,5 @@ fun bar() = 42 fun foo() = bar foo()() //│ ╔══[COMPILATION ERROR] Returning function instances is not supported -//│ ║ l.283: fun foo() = bar +//│ ║ l.334: fun foo() = bar //│ ╙── ^^^ From 91557689791608ff94743e8e9c28d31cda76ed1b Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 2 Jul 2026 16:05:13 +0800 Subject: [PATCH 46/48] WIP: Implement lowering of classes to specific refs --- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 4 +- .../shared/src/test/mlscript/wasm/Basics.mls | 65 ++++++++++++++++++- .../src/test/mlscript/wasm/VirtualMethods.mls | 4 +- 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 13a52c10b3..1a5fe13d23 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -53,11 +53,11 @@ extension (sym: ValueSymbol) structSym.flatMap(ctx.getType).map(RefType(_, nullable = false)).getOrElse(RefType.anyref) case _ => RefType.anyref - /** The Wasm reference type a parameter slot for `sym` should be declared with, if typed parameters are enabed. */ + /** The Wasm reference type a parameter slot for `sym` should be declared with, if typed parameters are enabled. */ private[text] def paramRefType(using Ctx, State): RefType = sym match case s: HasErasedType => - s.erasedType.collect { case p: ErasedType.Primitive => p }.flatMap(_.wasmType).getOrElse(RefType.anyref) + s.erasedType.flatMap(_.wasmType).getOrElse(RefType.anyref) case _ => RefType.anyref extension (exprs: Seq[Expr]) diff --git a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls index ad2f679c56..5a532c5cf6 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls @@ -202,6 +202,69 @@ getX(Foo(42)) //│ = 42 +// Checks that `f` in `getX` is typed as `(ref $Foo)` (evident by the type annotation) +:noInline +:wat +class Foo(val x) +fun getX(f: Foo) = f.x +getX(Foo(42)) +//│ Wat: +//│ (module +//│ (type $TypeInfoBase (sub (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) +//│ (type $Object (sub (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) +//│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) +//│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) +//│ (type $Foo_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $x (ref null any)) (result (ref null any)))) +//│ (type $getX (func (param $f (ref $Foo)) (result (ref null any)))) +//│ (type $entry (func (result (ref null any)))) +//│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo +//│ (i32.const 1) +//│ (ref.null $TypeInfoBase))) +//│ (func $Foo_init (type $Foo_init) (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)) +//│ (block (result (ref null any)) +//│ (struct.set $Foo $x +//│ (ref.cast (ref $Foo) +//│ (local.get $this)) +//│ (local.get $x)) +//│ (return +//│ (local.get $this)))) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref null any)) +//│ (local $this (ref $Foo)) +//│ (block (result (ref null any)) +//│ (local.set $this +//│ (struct.new $Foo +//│ (global.get $Foo_typeinfo) +//│ (ref.null any))) +//│ (drop +//│ (call $Foo_init +//│ (local.get $this) +//│ (local.get $x))) +//│ (return +//│ (local.get $this)))) +//│ (func $getX (export "getX") (type $getX) (param $f (ref $Foo)) (result (ref null any)) +//│ (return +//│ (struct.get $Foo $x +//│ (local.get $f)))) +//│ (func $entry (export "entry") (type $entry) (result (ref null any)) +//│ (local $tmp (ref null any)) +//│ (block (result (ref null any)) +//│ (local.set $tmp +//│ (call $Foo_ctor +//│ (ref.i31 +//│ (i32.const 42)))) +//│ (return +//│ (call $getX +//│ (ref.cast (ref $Foo) +//│ (local.get $tmp)))))) +//│ (elem $Foo_init declare func $Foo_init) +//│ (elem $Foo_ctor declare func $Foo_ctor) +//│ (elem $getX declare func $getX) +//│ (elem $entry declare func $entry)) +//│ Wasm result: +//│ = 42 + + :wat class Foo(val x) with val y = this.x @@ -334,5 +397,5 @@ fun bar() = 42 fun foo() = bar foo()() //│ ╔══[COMPILATION ERROR] Returning function instances is not supported -//│ ║ l.334: fun foo() = bar +//│ ║ l.397: fun foo() = bar //│ ╙── ^^^ diff --git a/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls b/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls index b8060a5c95..ecfc3a5453 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls @@ -33,7 +33,7 @@ callF(B()) //│ (type $B (sub $A (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $B_init (func (param $this (ref null any)) (result (ref null any)))) //│ (type $B_ctor (func (result (ref null any)))) -//│ (type $callF (func (param $a (ref null any)) (result (ref null any)))) +//│ (type $callF (func (param $a (ref $A)) (result (ref null any)))) //│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $times_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $entry (func (result (ref null any)))) @@ -91,7 +91,7 @@ callF(B()) //│ (local.get $x) //│ (ref.i31 //│ (i32.const 2))))) -//│ (func $callF (export "callF") (type $callF) (param $a (ref null any)) (result (ref null any)) +//│ (func $callF (export "callF") (type $callF) (param $a (ref $A)) (result (ref null any)) //│ (local $receiver (ref null any)) //│ (return //│ (block (result (ref null any)) From 47e729cce4ee4b74468eafe708247f73f353ba74 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 2 Jul 2026 16:35:19 +0800 Subject: [PATCH 47/48] wasm: Return concrete ref-type from constructors --- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 56 ++++++++++--------- .../shared/src/test/mlscript/wasm/Basics.mls | 29 +++++----- .../test/mlscript/wasm/ClassInheritance.mls | 12 ++-- .../src/test/mlscript/wasm/ClassMethods.mls | 6 +- .../src/test/mlscript/wasm/Matching.mls | 12 ++-- .../src/test/mlscript/wasm/ScopedLocals.mls | 6 +- .../src/test/mlscript/wasm/SingletonUnit.mls | 11 ++-- .../src/test/mlscript/wasm/Singletons.mls | 18 +++--- .../src/test/mlscript/wasm/VirtualMethods.mls | 12 ++-- 9 files changed, 80 insertions(+), 82 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 1a5fe13d23..67f524c2f9 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -47,9 +47,9 @@ extension (sym: ValueSymbol) import Ctx.ctx sym match case isym: InnerSymbol => - val structSym = isym.asBlkMember orElse: - Option.when(isym eq State.unitSymbol): - State.unitBlockMemberSymbol + val structSym = + if isym eq State.unitSymbol then S(State.unitBlockMemberSymbol) + else isym.asBlkMember structSym.flatMap(ctx.getType).map(RefType(_, nullable = false)).getOrElse(RefType.anyref) case _ => RefType.anyref @@ -351,9 +351,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val ctorCall = call( funcidx = ctx.getFunc_!(clsLikeDefn.sym), operands = Seq.empty, - returnTypes = Seq(Result(RefType.anyref)), + returnTypes = Seq(Result(globalTy)), ) - ctx.addSingletonInitAction(global.set(globalIdx, ref.cast(ctorCall, globalTy))) + ctx.addSingletonInitAction(global.set(globalIdx, ctorCall)) end registerSingletonInit /** Collects only top-level class definitions in `block`. */ @@ -535,12 +535,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: defn: ClsLikeDefn, suffix: Str, params: Seq[ValueSymbol -> SymIdx], + results: Seq[Result], )(using Ctx, Raise): TypeIdx = ctx.addType(TypeInfo( sym = TempSymbol(N, erasedType = N, defn.sym.nme), FunctionType( params = params.map(p => WasmParam(p._2, RefType.anyref)), - results = Seq(Result(RefType.anyref)), + results = results, ), objectTag = N, wrapId = N -> S(suffix), @@ -573,17 +574,19 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: defn: ClsLikeDefn, suffix: Str, params: Seq[ValueSymbol -> SymIdx], + results: Seq[Result], sym: BlockMemberSymbol, exportName: Opt[Str], )(using Ctx, Raise): Unit = - val funcTy = declareClassFuncType(defn, suffix, params) - predeclareClassFuncWithType(defn, suffix, params, sym, exportName, funcTy) + val funcTy = declareClassFuncType(defn, suffix, params, results) + predeclareClassFuncWithType(defn, suffix, params, results, sym, exportName, funcTy) /** Registers a placeholder class-associated function using a predeclared Wasm function type. */ private def predeclareClassFuncWithType( defn: ClsLikeDefn, suffix: Str, params: Seq[ValueSymbol -> SymIdx], + resultTypes: Seq[Result], sym: BlockMemberSymbol, exportName: Opt[Str], funcTy: TypeIdx, @@ -593,7 +596,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: wrapId = if sym.asClsOrMod.isDefined then (N -> S("ctor")) else (S(defn.sym.nme) -> N), typeUse = TypeUse(funcTy), params = params, - resultTypes = Seq(Result(RefType.anyref)), + resultTypes = resultTypes, locals = Seq.empty, body = ref.`null`(ctx.getType_!(defn.sym)), exportName = exportName, @@ -620,17 +623,18 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val initParams = (defn.isym -> SymIdx("this")) +: pl.params.map: p => p.sym -> SymIdx(p.sym.nme) - predeclareClassFunc(defn, "init", initParams, initFuncSym(defn.sym), N) + predeclareClassFunc(defn, "init", initParams, Seq(Result(RefType.anyref)), initFuncSym(defn.sym), N) /** Declares one top-level class constructor. */ private def predeclareClassConstructor(defn: ClsLikeDefn)(using Ctx, Raise): Unit = + val typeIdx = ctx.getType_!(defn.sym) val ctorParams = classCtorParamList(defn).params.map: p => p.sym -> SymIdx(p.sym.nme) val ctorExportName = defn.sym .optionIf: sym => !(defn.k is syntax.Obj) && sym.nameIsMeaningful .map(_.nme) - predeclareClassFunc(defn, "ctor", ctorParams, defn.sym, ctorExportName) + predeclareClassFunc(defn, "ctor", ctorParams, Seq(Result(RefType(typeIdx, nullable = false))), defn.sym, ctorExportName) /** Registers all Wasm pre-declarations needed for one top-level class, in dependency order. */ private def predeclareClass(defn: ClsLikeDefn)(using Ctx, Raise, SessionExportCtx): Unit = @@ -844,12 +848,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ownerCls, methodDefn.sym.nme, methodParams, + Seq(Result(RefType.anyref)), methodDefn.sym, N, virtualMethodFuncType(methodParams.size), ) case N => - predeclareClassFunc(ownerCls, methodDefn.sym.nme, methodParams, methodDefn.sym, N) + predeclareClassFunc(ownerCls, methodDefn.sym.nme, methodParams, Seq(Result(RefType.anyref)), methodDefn.sym, N) /** Declares placeholders for all methods on one top-level class. */ private def predeclareClassMethods(defn: ClsLikeDefn)(using Ctx, Raise): Unit = @@ -1505,17 +1510,16 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: Ls(msg"Class path for an Instantiate(...) expression must be resolved" -> cls.toLoc), extraInfo = S(s"Block IR of `cls` expression: ${cls.toString}"), ) - val ctorClsBlkSym = ctorClsSym.asBlkMember match - case S(sym) => sym - case N => lastWords( - s"Expected resolved class for an Instantiate(...) expression to be a BlockMemberSymbol, but got ${ - ctorClsSym.getClass.getName - }", - ) - val ctorFuncIdx = ctx.getFunc(ctorClsBlkSym) match - case S(idx) => idx - case N => lastWords(s"Missing constructor definition for class ${ctorClsBlkSym.toString}") - call(funcidx = ctorFuncIdx, as.map(argument), Seq(Result(RefType.anyref))) + val ctorClsBlkSym = ctorClsSym.asBlkMember.getOrElse: + lastWords: + s"Expected resolved class for an Instantiate(...) expression to be a BlockMemberSymbol, but got ${ + ctorClsSym.getClass.getName + }" + val ctorFuncIdx = ctx.getFunc(ctorClsBlkSym).getOrElse: + lastWords(s"Missing constructor definition for class ${ctorClsBlkSym.toString}") + val ctorClsTypeIdx = ctx.getType(ctorClsBlkSym).getOrElse: + lastWords(s"Missing class definition for class ${ctorClsBlkSym.toString}") + call(funcidx = ctorFuncIdx, as.map(argument), Seq(Result(RefType(ctorClsTypeIdx, nullable = false)))) case Tuple(mut, elems) => val tupleValues = elems.map(argument) @@ -1921,9 +1925,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: Seq( local.set(thisVar, struct.`new`(typeref, instanceFields)), drop(initCall), - // TODO(Derppening): Restore once we fix custom reftypes - // `return`(S(local.get(thisVar, RefType(typeref, nullable = false)))), - `return`(S(local.get(thisVar, RefType.anyref))), + `return`(S(local.get(thisVar, RefType(typeref, nullable = false)))), ).mergeAsBlock_! val predeclaredInit = ctx.getFuncInfo_!(initFuncRef) @@ -2001,7 +2003,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: funcType = FunctionType( SignatureType( params = ctorFnCtx.params.map(p => WasmParam(p._2, RefType.anyref)), - results = Seq(Result(RefType.anyref)), + results = ctorCode.resultTypes.map(ty => Result(ty.asValType_!)), ), ), )) diff --git a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls index 5a532c5cf6..6a7d578467 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls @@ -152,7 +152,7 @@ class Foo(val a) //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $a (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (param $a (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (param $a (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $a (ref null any)) (result (ref $Foo)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo //│ (i32.const 1) @@ -165,9 +165,9 @@ class Foo(val a) //│ (local.get $a)) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (result (ref null any)) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (result (ref $Foo)) //│ (local $this (ref $Foo)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) @@ -215,7 +215,7 @@ getX(Foo(42)) //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (param $x (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $x (ref null any)) (result (ref $Foo)))) //│ (type $getX (func (param $f (ref $Foo)) (result (ref null any)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo @@ -229,9 +229,9 @@ getX(Foo(42)) //│ (local.get $x)) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref null any)) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref $Foo)) //│ (local $this (ref $Foo)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) @@ -276,7 +276,7 @@ class Foo(val x) with //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any))) (field $y (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (param $x (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $x (ref null any)) (result (ref $Foo)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo //│ (i32.const 1) @@ -295,9 +295,9 @@ class Foo(val x) with //│ (local.get $this)))) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref null any)) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref $Foo)) //│ (local $this (ref $Foo)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) @@ -339,7 +339,7 @@ O.y //│ (type $O_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $O (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any))) (field $y (mut (ref null any)))))) //│ (type $O_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $O_ctor (func (result (ref null any)))) +//│ (type $O_ctor (func (result (ref $O)))) //│ (type $entry (func (result (ref null any)))) //│ (type $start (func)) //│ (global $O_typeinfo (export "O_typeinfo") (ref $O_typeinfo) (struct.new $O_typeinfo @@ -361,9 +361,9 @@ O.y //│ (local.get $this)))) //│ (return //│ (local.get $this)))) -//│ (func $O_ctor (type $O_ctor) (result (ref null any)) +//│ (func $O_ctor (type $O_ctor) (result (ref $O)) //│ (local $this (ref $O)) -//│ (block (result (ref null any)) +//│ (block (result (ref $O)) //│ (local.set $this //│ (struct.new $O //│ (global.get $O_typeinfo) @@ -376,8 +376,7 @@ O.y //│ (local.get $this)))) //│ (func $start (type $start) //│ (global.set $O$inst -//│ (ref.cast (ref null $O) -//│ (call $O_ctor)))) +//│ (call $O_ctor))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) //│ (return //│ (struct.get $O $y @@ -397,5 +396,5 @@ fun bar() = 42 fun foo() = bar foo()() //│ ╔══[COMPILATION ERROR] Returning function instances is not supported -//│ ║ l.397: fun foo() = bar +//│ ║ l.396: fun foo() = bar //│ ╙── ^^^ diff --git a/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls b/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls index 1f871e6cb2..f46f240ee1 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls @@ -31,11 +31,11 @@ c.Parent#x + (c is Parent) //│ (type $Parent_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Parent (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) //│ (type $Parent_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) -//│ (type $Parent_ctor (func (param $x (ref null any)) (result (ref null any)))) +//│ (type $Parent_ctor (func (param $x (ref null any)) (result (ref $Parent)))) //│ (type $Child_typeinfo (sub $Parent_typeinfo (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Child (sub $Parent (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any))) (field $y (mut (ref null any)))))) //│ (type $Child_init (func (param $this (ref null any)) (param $y (ref null any)) (result (ref null any)))) -//│ (type $Child_ctor (func (param $y (ref null any)) (result (ref null any)))) +//│ (type $Child_ctor (func (param $y (ref null any)) (result (ref $Child)))) //│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $entry (func (result (ref null any)))) //│ (import "system" "plus_impl" (func $plus_impl (type $plus_impl))) @@ -54,9 +54,9 @@ c.Parent#x + (c is Parent) //│ (local.get $x)) //│ (return //│ (local.get $this)))) -//│ (func $Parent_ctor (export "Parent") (type $Parent_ctor) (param $x (ref null any)) (result (ref null any)) +//│ (func $Parent_ctor (export "Parent") (type $Parent_ctor) (param $x (ref null any)) (result (ref $Parent)) //│ (local $this (ref $Parent)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Parent)) //│ (local.set $this //│ (struct.new $Parent //│ (global.get $Parent_typeinfo) @@ -85,9 +85,9 @@ c.Parent#x + (c is Parent) //│ (local.get $y)) //│ (return //│ (local.get $this)))) -//│ (func $Child_ctor (export "Child") (type $Child_ctor) (param $y (ref null any)) (result (ref null any)) +//│ (func $Child_ctor (export "Child") (type $Child_ctor) (param $y (ref null any)) (result (ref $Child)) //│ (local $this (ref $Child)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Child)) //│ (local.set $this //│ (struct.new $Child //│ (global.get $Child_typeinfo) diff --git a/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls b/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls index f042277f81..b960180970 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls @@ -15,7 +15,7 @@ a.A#get() + A(2).get() //│ (type $A_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $A (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) //│ (type $A_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) -//│ (type $A_ctor (func (param $x (ref null any)) (result (ref null any)))) +//│ (type $A_ctor (func (param $x (ref null any)) (result (ref $A)))) //│ (type $A_get (func (param $this (ref null any)) (result (ref null any)))) //│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $entry (func (result (ref null any)))) @@ -32,9 +32,9 @@ a.A#get() + A(2).get() //│ (local.get $x)) //│ (return //│ (local.get $this)))) -//│ (func $A_ctor (export "A") (type $A_ctor) (param $x (ref null any)) (result (ref null any)) +//│ (func $A_ctor (export "A") (type $A_ctor) (param $x (ref null any)) (result (ref $A)) //│ (local $this (ref $A)) -//│ (block (result (ref null any)) +//│ (block (result (ref $A)) //│ (local.set $this //│ (struct.new $A //│ (global.get $A_typeinfo) diff --git a/hkmc2/shared/src/test/mlscript/wasm/Matching.mls b/hkmc2/shared/src/test/mlscript/wasm/Matching.mls index 1243edf2ea..647dfe5f8f 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Matching.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Matching.mls @@ -48,11 +48,11 @@ if Bar(true) is //│ (type $Bar_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Bar (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $y (mut (ref null any)))))) //│ (type $Bar_init (func (param $this (ref null any)) (param $y (ref null any)) (result (ref null any)))) -//│ (type $Bar_ctor (func (param $y (ref null any)) (result (ref null any)))) +//│ (type $Bar_ctor (func (param $y (ref null any)) (result (ref $Bar)))) //│ (type $Baz_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Baz (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $z (mut (ref null any)))))) //│ (type $Baz_init (func (param $this (ref null any)) (param $z (ref null any)) (result (ref null any)))) -//│ (type $Baz_ctor (func (param $z (ref null any)) (result (ref null any)))) +//│ (type $Baz_ctor (func (param $z (ref null any)) (result (ref $Baz)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Bar_typeinfo (export "Bar_typeinfo") (ref $Bar_typeinfo) (struct.new $Bar_typeinfo //│ (i32.const 1) @@ -68,9 +68,9 @@ if Bar(true) is //│ (local.get $y)) //│ (return //│ (local.get $this)))) -//│ (func $Bar_ctor (export "Bar") (type $Bar_ctor) (param $y (ref null any)) (result (ref null any)) +//│ (func $Bar_ctor (export "Bar") (type $Bar_ctor) (param $y (ref null any)) (result (ref $Bar)) //│ (local $this (ref $Bar)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Bar)) //│ (local.set $this //│ (struct.new $Bar //│ (global.get $Bar_typeinfo) @@ -89,9 +89,9 @@ if Bar(true) is //│ (local.get $z)) //│ (return //│ (local.get $this)))) -//│ (func $Baz_ctor (export "Baz") (type $Baz_ctor) (param $z (ref null any)) (result (ref null any)) +//│ (func $Baz_ctor (export "Baz") (type $Baz_ctor) (param $z (ref null any)) (result (ref $Baz)) //│ (local $this (ref $Baz)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Baz)) //│ (local.set $this //│ (struct.new $Baz //│ (global.get $Baz_typeinfo) diff --git a/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls b/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls index 186fd5d800..1fd6c40b0c 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls @@ -149,7 +149,7 @@ class Foo(val a, val b) //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $a (mut (ref null any))) (field $b (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (param $a (ref null any)) (param $b (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (param $a (ref null any)) (param $b (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $a (ref null any)) (param $b (ref null any)) (result (ref $Foo)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo //│ (i32.const 1) @@ -166,9 +166,9 @@ class Foo(val a, val b) //│ (local.get $b)) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (param $b (ref null any)) (result (ref null any)) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (param $b (ref null any)) (result (ref $Foo)) //│ (local $this (ref $Foo)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) diff --git a/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls b/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls index adb7c68bcf..51b3f24893 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls @@ -11,7 +11,7 @@ //│ (type $Unit_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Unit (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $Unit_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $Unit_ctor (func (result (ref null any)))) +//│ (type $Unit_ctor (func (result (ref $Unit)))) //│ (type $entry (func (result (ref null any)))) //│ (type $start (func)) //│ (global $Unit_typeinfo (export "Unit_typeinfo") (ref $Unit_typeinfo) (struct.new $Unit_typeinfo @@ -21,9 +21,9 @@ //│ (func $Unit_init (type $Unit_init) (param $this (ref null any)) (result (ref null any)) //│ (return //│ (local.get $this))) -//│ (func $Unit_Unit (type $Unit_ctor) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Unit_Unit (type $Unit_ctor) (result (ref $Unit)) +//│ (local $this (ref $Unit)) +//│ (block (result (ref $Unit)) //│ (local.set $this //│ (struct.new $Unit //│ (global.get $Unit_typeinfo))) @@ -34,8 +34,7 @@ //│ (local.get $this)))) //│ (func $start (type $start) //│ (global.set $Unit$inst -//│ (ref.cast (ref null $Unit) -//│ (call $Unit_Unit)))) +//│ (call $Unit_Unit))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) //│ (return //│ (global.get $Unit$inst))) diff --git a/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls b/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls index 9891c58f6e..d0ca4f6efa 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls @@ -16,11 +16,11 @@ Bar.y //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (result (ref null any)))) +//│ (type $Foo_ctor (func (result (ref $Foo)))) //│ (type $Bar_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Bar (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $y (mut (ref null any)))))) //│ (type $Bar_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $Bar_ctor (func (result (ref null any)))) +//│ (type $Bar_ctor (func (result (ref $Bar)))) //│ (type $entry (func (result (ref null any)))) //│ (type $start (func)) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo @@ -40,9 +40,9 @@ Bar.y //│ (i32.const 1))) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (type $Foo_ctor) (result (ref null any)) +//│ (func $Foo_ctor (type $Foo_ctor) (result (ref $Foo)) //│ (local $this (ref $Foo)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) @@ -61,9 +61,9 @@ Bar.y //│ (i32.const 2))) //│ (return //│ (local.get $this)))) -//│ (func $Bar_ctor (type $Bar_ctor) (result (ref null any)) +//│ (func $Bar_ctor (type $Bar_ctor) (result (ref $Bar)) //│ (local $this (ref $Bar)) -//│ (block (result (ref null any)) +//│ (block (result (ref $Bar)) //│ (local.set $this //│ (struct.new $Bar //│ (global.get $Bar_typeinfo) @@ -76,11 +76,9 @@ Bar.y //│ (func $start (type $start) //│ (block //│ (global.set $Foo$inst -//│ (ref.cast (ref null $Foo) -//│ (call $Foo_ctor))) +//│ (call $Foo_ctor)) //│ (global.set $Bar$inst -//│ (ref.cast (ref null $Bar) -//│ (call $Bar_ctor))))) +//│ (call $Bar_ctor)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) //│ (return //│ (struct.get $Bar $y diff --git a/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls b/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls index ecfc3a5453..e497d1c179 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls @@ -28,11 +28,11 @@ callF(B()) //│ (type $A_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase)) (field $slot0 (mut (ref null $virtual2)))))) //│ (type $A (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $A_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $A_ctor (func (result (ref null any)))) +//│ (type $A_ctor (func (result (ref $A)))) //│ (type $B_typeinfo (sub $A_typeinfo (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase)) (field $slot0 (mut (ref null $virtual2)))))) //│ (type $B (sub $A (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $B_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $B_ctor (func (result (ref null any)))) +//│ (type $B_ctor (func (result (ref $B)))) //│ (type $callF (func (param $a (ref $A)) (result (ref null any)))) //│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $times_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) @@ -50,9 +50,9 @@ callF(B()) //│ (func $A_init (type $A_init) (param $this (ref null any)) (result (ref null any)) //│ (return //│ (local.get $this))) -//│ (func $A_ctor (export "A") (type $A_ctor) (result (ref null any)) +//│ (func $A_ctor (export "A") (type $A_ctor) (result (ref $A)) //│ (local $this (ref $A)) -//│ (block (result (ref null any)) +//│ (block (result (ref $A)) //│ (local.set $this //│ (struct.new $A //│ (global.get $A_typeinfo))) @@ -74,9 +74,9 @@ callF(B()) //│ (local.get $this))) //│ (return //│ (local.get $this)))) -//│ (func $B_ctor (export "B") (type $B_ctor) (result (ref null any)) +//│ (func $B_ctor (export "B") (type $B_ctor) (result (ref $B)) //│ (local $this (ref $B)) -//│ (block (result (ref null any)) +//│ (block (result (ref $B)) //│ (local.set $this //│ (struct.new $B //│ (global.get $B_typeinfo))) From 3697356902022676e40c1f3a41a3db9ddf109df8 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 2 Jul 2026 16:53:36 +0800 Subject: [PATCH 48/48] wasm: Further restrict types of locals --- .../hkmc2/codegen/wasm/text/WatBuilder.scala | 25 ++++++++++++++++--- .../shared/src/test/mlscript/wasm/Basics.mls | 12 ++++----- .../test/mlscript/wasm/ClassInheritance.mls | 2 +- .../src/test/mlscript/wasm/Matching.mls | 4 +-- .../src/test/mlscript/wasm/ScopedLocals.mls | 5 ++-- 5 files changed, 31 insertions(+), 17 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 67f524c2f9..4532dc5ac5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -51,6 +51,8 @@ extension (sym: ValueSymbol) if isym eq State.unitSymbol then S(State.unitBlockMemberSymbol) else isym.asBlkMember structSym.flatMap(ctx.getType).map(RefType(_, nullable = false)).getOrElse(RefType.anyref) + case s: HasErasedType => + s.erasedType.flatMap(_.wasmType).getOrElse(RefType.anyref) case _ => RefType.anyref /** The Wasm reference type a parameter slot for `sym` should be declared with, if typed parameters are enabled. */ @@ -172,6 +174,14 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: require(expr.resultTypes.size == 1, "expected single-result expression for cast") if expr.resultType.contains(target) then expr else ref.cast(expr, target) + /** Casts an expression to `target` type if the result type is a supertype of `target`. */ + private def downcastConserve(expr: Expr, target: RefType): Expr = + require(expr.resultTypes.size == 1, "expected single-result expression for cast") + target match + case rt: RefType if rt.heapType =/= HeapType.Any && expr.resultType.contains(RefType.anyref) => + castConserve(expr, rt) + case _ => expr + /** Casts each argument in `wasmArgs` down to the corresponding declared parameter type read from `funcTypeInfo`, * narrowing `anyref` -> a concrete typed parameter. * @@ -1700,11 +1710,14 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case Assign(l: ValueSymbol, r, rst) => val lExpr = getVar(l, l.toLoc) val rExpr = result(r) + val rExprCasted = lExpr.resultType match + case S(rt: RefType) => downcastConserve(rExpr, rt) + case _ => rExpr val assignExpr = lExpr.mnemonicPrefix match case S("global") => - global.set(lExpr.instrargs(0).asInstanceOf[GlobalIdx], rExpr) + global.set(lExpr.instrargs(0).asInstanceOf[GlobalIdx], rExprCasted) case S("local") => - local.set(lExpr.instrargs(0).asInstanceOf[LocalIdx], rExpr) + local.set(lExpr.instrargs(0).asInstanceOf[LocalIdx], rExprCasted) case _ => lastWords( s"Expected `global.*` or `local.*` when compiling instruction for `$l`, but got ${lExpr.mnemonic}", @@ -1786,11 +1799,15 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case N => val localStorageSym = defn.sym val symExpr = getVar(localStorageSym, localStorageSym.toLoc) + val pExpr = result(p) + val pExprCasted = symExpr.resultType match + case S(rt: RefType) => downcastConserve(pExpr, rt) + case _ => pExpr val defineExpr = symExpr.mnemonicPrefix match case S("global") => - global.set(symExpr.instrargs(0).asInstanceOf[GlobalIdx], result(p)) + global.set(symExpr.instrargs(0).asInstanceOf[GlobalIdx], pExprCasted) case S("local") => - local.set(symExpr.instrargs(0).asInstanceOf[LocalIdx], result(p)) + local.set(symExpr.instrargs(0).asInstanceOf[LocalIdx], pExprCasted) case _ => lastWords( s"Expected `global.*` or `local.*` when compiling definition for `$sym`, but got ${symExpr.mnemonic}", diff --git a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls index 6a7d578467..685f701941 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls @@ -179,7 +179,7 @@ class Foo(val a) //│ (return //│ (local.get $this)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) -//│ (local $tmp (ref null any)) +//│ (local $tmp (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $tmp //│ (call $Foo_ctor @@ -187,8 +187,7 @@ class Foo(val a) //│ (i32.const 42)))) //│ (return //│ (struct.get $Foo $a -//│ (ref.cast (ref $Foo) -//│ (local.get $tmp)))))) +//│ (local.get $tmp))))) //│ (elem $Foo_init declare func $Foo_init) //│ (elem $Foo_ctor declare func $Foo_ctor) //│ (elem $entry declare func $entry)) @@ -310,7 +309,7 @@ class Foo(val x) with //│ (return //│ (local.get $this)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) -//│ (local $tmp (ref null any)) +//│ (local $tmp (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $tmp //│ (call $Foo_ctor @@ -318,8 +317,7 @@ class Foo(val x) with //│ (i32.const 42)))) //│ (return //│ (struct.get $Foo $y -//│ (ref.cast (ref $Foo) -//│ (local.get $tmp)))))) +//│ (local.get $tmp))))) //│ (elem $Foo_init declare func $Foo_init) //│ (elem $Foo_ctor declare func $Foo_ctor) //│ (elem $entry declare func $entry)) @@ -396,5 +394,5 @@ fun bar() = 42 fun foo() = bar foo()() //│ ╔══[COMPILATION ERROR] Returning function instances is not supported -//│ ║ l.396: fun foo() = bar +//│ ║ l.394: fun foo() = bar //│ ╙── ^^^ diff --git a/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls b/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls index f46f240ee1..a00975aa8e 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls @@ -103,7 +103,7 @@ c.Parent#x + (c is Parent) //│ (local $matchRes (ref null any)) //│ (local $currentTypeInfo (ref null any)) //│ (local $targetTypeInfo (ref null any)) -//│ (local $typeInfoMatch (ref null any)) +//│ (local $typeInfoMatch (ref null i31)) //│ (block (result (ref null any)) //│ (global.set $c //│ (call $Child_ctor diff --git a/hkmc2/shared/src/test/mlscript/wasm/Matching.mls b/hkmc2/shared/src/test/mlscript/wasm/Matching.mls index 647dfe5f8f..0c5a5a1c9a 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Matching.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Matching.mls @@ -108,11 +108,11 @@ if Bar(true) is //│ (local $matchRes (ref null any)) //│ (local $currentTypeInfo (ref null any)) //│ (local $targetTypeInfo (ref null any)) -//│ (local $typeInfoMatch (ref null any)) +//│ (local $typeInfoMatch (ref null i31)) //│ (local $matchRes1 (ref null any)) //│ (local $currentTypeInfo1 (ref null any)) //│ (local $targetTypeInfo1 (ref null any)) -//│ (local $typeInfoMatch1 (ref null any)) +//│ (local $typeInfoMatch1 (ref null i31)) //│ (block (result (ref null any)) //│ (local.set $scrut //│ (call $Bar_ctor diff --git a/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls b/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls index 1fd6c40b0c..4def17a2d3 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls @@ -182,7 +182,7 @@ class Foo(val a, val b) //│ (return //│ (local.get $this)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) -//│ (local $tmp (ref null any)) +//│ (local $tmp (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $tmp //│ (call $Foo_ctor @@ -192,8 +192,7 @@ class Foo(val a, val b) //│ (i32.const 1)))) //│ (return //│ (struct.get $Foo $a -//│ (ref.cast (ref $Foo) -//│ (local.get $tmp)))))) +//│ (local.get $tmp))))) //│ (elem $Foo_init declare func $Foo_init) //│ (elem $Foo_ctor declare func $Foo_ctor) //│ (elem $entry declare func $entry))