diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index a991599478..3d425e295d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -12,7 +12,8 @@ import hkmc2.semantics.{Term => st} import syntax.{Literal, Tree, SpreadKind, Keyword} import semantics.* import semantics.Term.* -import sem.Elaborator.State +import sem.Elaborator.{Ctx, State, ctx} +import sourcecode.{FileName, Line} /* Important design notes. @@ -45,7 +46,7 @@ case class Program( type SimpleSymbol = LocalVarSymbol | BuiltinSymbol /** Symbol that can be used as the left-hand side of an `Assign`. */ -type Assignable = LocalVarSymbol | NoSymbol +type Assignable = LocalVarSymbol | NoSymbol.type /** Symbols that `Scoped` introduces as block-local bindings. * This deliberately excludes things like `TermSymbol`s, which never need to be scoped. @@ -552,7 +553,7 @@ object HandleBlock: N, sym, PlainParamList(Param(FldFlags.empty, handler.resumeSym, N, Modulefulness.none) :: Nil) :: Nil, handler.body )(N, annotations = Nil) - val rSym = TempSymbol(N, "suspendRes") + val rSym = TempSymbol(N, erasedType = N, "suspendRes") FunDefn.withFreshSymbol( S(cls), handler.sym, @@ -570,7 +571,7 @@ object HandleBlock: N, Nil, S(par), handlerMtds, Nil, Nil, // Apparently, the lifter is not happy with any assignment in the preCtor... - Assign(State.noSymbol, Call(State.builtinOpsMap("super").asSimpleRef, args.map(_.asArg) ne_:: Nil)(CallMetadata.mlsFunWithEffect), End()), + Assign(NoSymbol, Call(State.builtinOpsMap("super").asSimpleRef, args.map(_.asArg) ne_:: Nil)(CallMetadata.mlsFunWithEffect), End()), End(), N, N, @@ -699,7 +700,7 @@ final case class FunDefn( object FunDefn: def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(configOverride: Opt[Config], annotations: Ls[Annot])(using State) = - val tSym = TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)) + val tSym = TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme), erasedType = N) sym.tsym = S(tSym) FunDefn(owner, sym, tSym, params, body)(configOverride, annotations) @@ -725,7 +726,7 @@ object ValDefn: annotations: Ls[Annot], )(using State) : ValDefn = - ValDefn(tsym = TermSymbol(k, owner, Tree.Ident(sym.nme)), sym = sym, rhs = rhs)(configOverride, annotations) + ValDefn(tsym = TermSymbol(k, owner, Tree.Ident(sym.nme), erasedType = rhs.erasedType), sym, rhs)(configOverride, annotations) /* @@ -857,9 +858,121 @@ enum Case: case Tup(_, _) => Set.empty case Field(_, _) => Set.empty +/** A primitive type of the block IR. */ +enum PrimitiveType: + case Unit, Int, Int31, Num, Str, Bool, Array + + /** The symbol for this primitive type. */ + def sym(using Ctx, State): ClassLikeSymbol = this match + case Unit => summon[State].unitSymbol + case Int => ctx.builtins.Int + case Int31 => ctx.builtins.Int31 + case Num => ctx.builtins.Num + case Str => ctx.builtins.Str + case Bool => ctx.builtins.Bool + case Array => ctx.builtins.Array + +object ErasedType: + def ObjectRef: ErasedType.AnyRef = AnyRef(rsc = false, NoSymbol) + + /** Maps a [[`ClassLikeSymbol`]] into the canonical [[`ErasedType`]]. */ + def fromClsLikeSymbol(csym: ClassLikeSymbol, rsc: Bool)(using Ctx, State): ErasedType = + PrimitiveType.values.find(_.sym === csym) match + case _ if csym === ctx.builtins.Object => ObjectRef + case S(prim) => ErasedType.Primitive(prim) + case _ => ErasedType.AnyRef(rsc, csym) + + /** Joins two possibly-unknown erased types and returning the result. + * + * - If both sides are known and equal, returns that known type. + * - If both sides are known and unequal, returns the top type [[`ErasedType.ObjectRef`]]. + * - If either side is unknown (`N`), returns `N` - Representing a possibly-unknown type. + */ + // TODO(Derppening): Widen conflicting knowns to their common ancestor once erased types track subtyping, + // rather than going directly to the top type. + def join(lhs: Opt[ErasedType], rhs: Opt[ErasedType]): Opt[ErasedType] = (lhs, rhs) match + case (N, _) | (_, N) => N + case (S(l), S(r)) => if l == r then S(l) else S(ObjectRef) + +/** A generics-erased type of the Block IR. */ +enum ErasedType: + /** + * An reference to a class-like symbol. + * + * If `csym` is `NoSymbol`, this represents the top type (`Object`). + * + * - `rsc` is true if this reference is a resource class. + */ + case AnyRef(rsc: Bool, csym: ClassLikeSymbol | NoSymbol.type) + + /** A reference to a function of a possibly-known shape. */ + case FuncRef(params: Ls[Opt[ErasedType]], ret: Opt[ErasedType]) + + /** An primitive type. */ + case Primitive(prim: PrimitiveType) + + /** The symbol for this erased type. */ + def sym(using Ctx, State): ClassLikeSymbol = this match + case AnyRef(_, csym: ClassLikeSymbol) => csym + case AnyRef(_, NoSymbol) => ctx.builtins.Object + case FuncRef(_, _) => ctx.builtins.Function + case Primitive(prim) => prim.sym + +/** Trait representing a Block IR element that has an [[`ErasedType`]]. */ +trait HasErasedType: + /** The [[`ErasedType`]] of this element, or `N` if the erased type is not known. */ + def erasedType: Opt[ErasedType] + + /** Similar to `erasedType`, but coerces to the top type if the specific erased type is not known. */ + def erasedType_! : ErasedType = erasedType.getOrElse(ErasedType.ObjectRef) + +/** A [[`HasErasedType`]] whose erased type can be populated exactly once post-construction. */ +trait HasOnceMutableErasedType extends HasErasedType: + // Implementation Note: Provided for overriding classes to implement `erasedType` directly as an `override var` + def erasedType_=(newType: Opt[ErasedType]): Unit + + /** Populates the erased type, or raises a soft assertion if the type was already populated. */ + def populateErasedType(newType: ErasedType)(using Line, FileName, Raise): Unit = + // TODO(Derppening): Restore `erasedType.isEmpty` once JS sanitization is converted into a pass, allowing us to + // only lower each program once + softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType") + if erasedType.isEmpty then erasedType = S(newType) + +/** A [[`HasOnceMutableErasedType`]] whose erased type can additionally be assigned multiple times, joining each + * observed type into the running erased type via [[`ErasedType.join`]]. */ +trait HasManyMutableErasedType extends HasOnceMutableErasedType: + /** Tracks whether [[`observeErasedTypeAssign`]] has seen at least one assignment to this symbol. Needed to + * distinguish a never-observed symbol (whose `N` erased type is the join unit, to be seeded by the first + * observation) from an observed-but-poisoned one (whose `N` erased type is the absorbing top). */ + private var erasedTypeObserved: Bool = false + + /** + * Observes an assignment of a value with type `observed` to this symbol. + * + * Some symbols (e.g. [[`LocalVarSymbol`]]) can be assigned multiple times with values of differing types. The + * first observation of an otherwise-unset symbol seeds the erased type directly; every subsequent observation + * (and any observation joining an annotation-populated type) is folded in with [[`ErasedType.join`]]. Under that + * join an unknown (`N`) assignment is "poison": it widens the symbol to the unknown top type and is absorbing, + * while two differing known types widen to the definitive top type [[`ErasedType.ObjectRef`]]. + */ + def observeErasedTypeAssign(observed: Opt[ErasedType]): Unit = + if !erasedTypeObserved && erasedType.isEmpty then + erasedType = observed + else + erasedType = ErasedType.join(erasedType, observed) + erasedTypeObserved = true + +extension (lit: Literal) + def erasedType: ErasedType = lit match + case Tree.UnitLit(_) => ErasedType.Primitive(PrimitiveType.Unit) + case Tree.IntLit(_) => ErasedType.Primitive(PrimitiveType.Int) + case Tree.DecLit(_) => ErasedType.Primitive(PrimitiveType.Num) + case Tree.StrLit(_) => ErasedType.Primitive(PrimitiveType.Str) + case Tree.BoolLit(_) => ErasedType.Primitive(PrimitiveType.Bool) + sealed trait TrivialResult extends Result -sealed abstract class Result extends AutoLocated: +sealed abstract class Result extends AutoLocated, HasErasedType: // // * Used for debugging locations: // sealed abstract class Result extends AutoLocated with ProductWithExtraInfo: // def extraInfo: Str = toLoc.toString @@ -944,6 +1057,29 @@ sealed abstract class Result extends AutoLocated: case Value.Lit(lit) => 0 case DynSelect(qual, fld, arrayIdx) => qual.size + fld.size + lazy val erasedType: Opt[ErasedType] = this match + case Tuple(_, _) => S(ErasedType.Primitive(PrimitiveType.Array)) + case Record(_, _) => S(ErasedType.ObjectRef) + case Instantiate(_, cls, _) => cls.targetSymbol.flatMap(_.asClsOrMod).map(ErasedType.AnyRef(rsc = false, _)) + case Value.SimpleRef(sym) => sym.erasedType + case Value.MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => disamb.erasedType + case Value.This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => clsOrMod.erasedType + case Value.Lit(lit) => S(lit.erasedType) + case Call(Value.SimpleRef(bs: BuiltinSymbol), argss) => + bs.resultErasedType(argss.head.map(_.value.erasedType)) + case Call(fun, _) => fun.targetSymbol match + case S(ts: TermSymbol) => ts.erasedType match + case S(ErasedType.FuncRef(_, ret)) => ret + case _ => N + case _ => N + // * A resolved selection has the type of the member it refers to (e.g. `this.field`); an + // * unresolved selection (dynamic field access) stays unknown. + case sel @ Select(_, _) => sel.symbol match + case S(ts: TermSymbol) => ts.erasedType + case S(d: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => d.erasedType + case _ => N + case _ => N + /* mayRaiseEffects indicates whether this call may raise effect (algebraic effect), * regardless of whether the check for effect is inserted or not. * Note that the check for effect is inserted during HandlerLowering and setting this to true @@ -1085,7 +1221,7 @@ object Value: case SimpleRef(l) => l case MemberRef(bms, disamb) => bms case This(sym) => sym - + @deprecated("Use Value.SimpleRef, Value.MemberRef, or Value.This instead.") object Ref: def apply(l: ValueSymbol | NoSymbol.type, disamb: Opt[DefinitionSymbol[?]]): Value.RefLike = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala index 92ca97effd..54675165f5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockSimplifier.scala @@ -1538,7 +1538,7 @@ class BlockSimplifier def go(acc: Block => Block, args: List[(VarSymbol, Result)], mapping: Map[Symbol, Symbol]): Block = args match case Nil => - val resSym = TempSymbol(N, "inlinedVal") + val resSym = TempSymbol(N, erasedType = N, "inlinedVal") val copier = Copier(resSym, mapping) val newBlk = copier.applyBlock(blk) if extraArgss.isEmpty then @@ -1550,7 +1550,7 @@ class BlockSimplifier annotations = c.metadata.annotations.filterNot(_ == Annot.TailCall), )))))) case (sym, value) :: argRest => - val newSym = VarSymbol(sym.id) + val newSym = VarSymbol(sym.id, erasedType = N) go(acc.assignScoped(newSym, value), argRest, mapping + (sym -> newSym)) go(blockBuilder, matchedArgs, Map.empty) case _ => super.applyResult(r)(k) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala index 9462078022..c3e11d63bc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala @@ -22,7 +22,7 @@ class BufferableTransform()(using Ctx, State, Raise): cls.bufferable.fold(super.applyDefn(defn)(k)): bufferable => val companionSym = ModuleOrObjectSymbol(DummyTypeDef(syntax.Mod), new Tree.Ident(cls.sym.nme)) val clsSizeSym = BlockMemberSymbol("size", Nil, false) - val clsSizeTermSym = TermSymbol(syntax.ImmutVal, S(companionSym), new Tree.Ident("size")) + val clsSizeTermSym = TermSymbol(syntax.ImmutVal, S(companionSym), new Tree.Ident("size"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int))) val pubFieldMap: Map[Symbol, Symbol] = cls.publicFields.toMap val fields = cls.privateFields ++ cls.publicFields.map(_._2) val fieldMap: Map[Symbol, Int] = fields.zipWithIndex.toMap @@ -30,18 +30,18 @@ class BufferableTransform()(using Ctx, State, Raise): val allVars = params.flatMap(_.allParams).map(_.sym) val varMap = allVars .map: sym => - (sym, VarSymbol(sym.id)) + (sym, VarSymbol(sym.id, erasedType = N)) .toMap def mapParam(p: Param) = Param(p.flags, varMap(p.sym), p.sign, p.modulefulness) (params.map(pl => ParamList(pl.flags, pl.params.map(mapParam), pl.restParam.map(mapParam))), varMap.toMap) def mkFieldReplacer(buf: VarSymbol, baseIdx: VarSymbol, symMap: Map[SimpleSymbol, SimpleSymbol]) = def getOffset(off: Int)(k: Path => Block): Block = - val idxSymbol = new TempSymbol(N, "idx") + val idxSymbol = new TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "idx") Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(CallMetadata.defaultMlsFun), k(DynSelect(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true)))) def assignToOffset(off: Int, r: Result, rst: Block) = - val idxSymbol = new TempSymbol(N, "idx") + val idxSymbol = new TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "idx") Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(CallMetadata.defaultMlsFun), AssignDynField(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true, r, applyBlock(rst)))) new BlockTransformer(SymbolSubst.Id): @@ -73,11 +73,11 @@ class BufferableTransform()(using Ctx, State, Raise): k(res) case _ => super.applyPath(p)(k) def transformFunDefn(f: FunDefn, isCtor: Bool): FunDefn = - val buf = VarSymbol(new Tree.Ident("buf")) - val idx = VarSymbol(new Tree.Ident("idx")) + val buf = VarSymbol(new Tree.Ident("buf"), erasedType = N) + val idx = VarSymbol(new Tree.Ident("idx"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int))) val (newParams, symMap) = mkSymbolReplacer(f.params) val blk = mkFieldReplacer(buf, idx, symMap).applyBlock(f.body) - FunDefn(f.owner, f.sym, TermSymbol(f.dSym.k, f.dSym.owner, f.dSym.id), PlainParamList( + FunDefn(f.owner, f.sym, TermSymbol(f.dSym.k, f.dSym.owner, f.dSym.id, erasedType = N), PlainParamList( Param(FldFlags.empty, buf, N, Modulefulness.none) :: Param(FldFlags.empty, idx, N, Modulefulness.none) :: Nil) :: newParams, if isCtor then Begin(blk, Return(idx.asSimpleRef)) else blk)(configOverride = f.configOverride, annotations = f.annotations) val fakeCtor = transformFunDefn(FunDefn.withFreshSymbol( diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala index 3f8d2158e0..1aeaede47e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala @@ -166,7 +166,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise): val name = instId.mkFunName + s"$$${f.nme}" f -> ( new BlockMemberSymbol(name, Nil, true), - new TermSymbol(Fun, N, Tree.Ident(name))) + new TermSymbol(Fun, N, Tree.Ident(name), erasedType = N)) .toMap) end mkNewPolyFnSyms @@ -335,12 +335,12 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise): case ParamList(flags, params, restParam) => val params2 = params.map: case p => - val newSym = new VarSymbol(Tree.Ident(p.sym.name)) + val newSym = new VarSymbol(Tree.Ident(p.sym.name), erasedType = N) refreshParamMap(p.sym) = newSym Param(p.flags, newSym, p.sign, p.modulefulness) val rest2 = restParam.map: case p => - val newSym = new VarSymbol(Tree.Ident(p.sym.name)) + val newSym = new VarSymbol(Tree.Ident(p.sym.name), erasedType = N) refreshParamMap(p.sym) = newSym Param(p.flags, newSym, p.sign, p.modulefulness) ParamList(flags, params2, rest2) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala index 7ecf86765b..cb53707460 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala @@ -173,7 +173,7 @@ class EtaExpansionRewrite(val etaExpansionSolver: EtaExpansionSolver)(using Rais targetShape.drop(existingShape.size).zipWithIndex.map: case (count, idx) => val params = (0 until count).toList.map: i => - Param.simple(new VarSymbol(new Tree.Ident(s"eta$$$idx$$$i"))) + Param.simple(new VarSymbol(new Tree.Ident(s"eta$$$idx$$$i"), erasedType = N)) EtaParamList( ParamList(ParamListFlags.empty, params, N), params.map(p => Arg(N, p.sym.asSimpleRef)), @@ -195,7 +195,7 @@ class EtaExpansionRewrite(val etaExpansionSolver: EtaExpansionSolver)(using Rais Return( Call(fun, (argss ++ activeEtaArgss).ne_!)(c.metadata)) case _ => - val tmp = TempSymbol(N, "eta$res") + val tmp = TempSymbol(N, erasedType = N, "eta$res") Scoped( Set.single(tmp), Assign(tmp, res2, Return(etaCall(tmp.asPath).withLocOf(res2)))) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala index 95a7246285..44d9b5f90f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/FirstClassFunctionTransformer.scala @@ -59,7 +59,7 @@ class FirstClassFunctionTransformer private def etaExpandPath(p: Path, params: ParamList)(k: Path => Block): Block = val clsDef = generateFCFunctionClass(p, params) - val tmp = new TempSymbol(None) + val tmp = new TempSymbol(None, erasedType = S(ErasedType.AnyRef(rsc = false, clsDef.isym.asClsOrMod.get))) val cls = clsDef.sym.asMemberRef(clsDef.isym) Scoped(Set(clsDef.sym, tmp), Define(clsDef, Assign(tmp, Instantiate(false, cls, Nil :: Nil)(InstantiateMetadata.empty), k(tmp.asSimpleRef)))) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 0536f02b4e..66a6b859df 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -119,7 +119,7 @@ class HandlerPaths(using Elaborator.State): class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, Elaborator.State, Elaborator.Ctx, Config): - private def freshTmp(dbgNme: Str = "tmp") = new TempSymbol(N, dbgNme) + private def freshTmp(erasedType: Opt[ErasedType], dbgNme: Str = "tmp") = new TempSymbol(N, erasedType, dbgNme) private def freshLabel(nme: Str) = new LabelSymbol(N, nme) private def rtThrowMsg(msg: Str) = Throw( @@ -139,7 +139,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case _ => N object StateTransition: - private val transitionSymbol = freshTmp("transition") + private val transitionSymbol = freshTmp(erasedType = N, "transition") def apply(uid: StateId) = Return(PureCall(transitionSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid))))) def unapply(blk: Block) = blk match @@ -148,7 +148,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case _ => N object Unwind: - private val unwindSymbol = freshTmp("unwind") + private val unwindSymbol = freshTmp(erasedType = N, "unwind") def apply(uid: StateId, loc: Value) = Return(PureCall(unwindSymbol.asSimpleRef, List(Value.Lit(Tree.IntLit(uid)), loc))) def unapply(blk: Block) = blk match @@ -538,7 +538,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, .flatMap: (sym, idx) => List(intLit(idx), Value.Lit(Tree.StrLit(sym.nme))) .map(_.asArg) - val debugInfoSym = freshTmp(s"$debugNme$$debugInfo") + val debugInfoSym = freshTmp(erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), s"$debugNme$$debugInfo") // TODO: properly support spread argument by calculating the correct length. val rtArgLists = intLit(fun.params.length) :: fun.params.flatMap: pl => intLit(pl.params.length) :: pl.params.map(p => p.sym.asSimpleRef) @@ -610,8 +610,8 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, return b val vars = if opt.debug then ctx.resumeInfo.currentLocals else computeRestoreList(parts) - val pcVar = freshTmp("pc") - val curDepth = freshTmp("curDepth") + val pcVar = freshTmp(erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "pc") + val curDepth = freshTmp(erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "curDepth") val mainLoopLbl = freshLabel("main") val edges = computeEdges(parts) @@ -685,7 +685,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case (acc, f) => f(acc) Label(mainLoopLbl, true, matches, End()) - val getSavedTmp = freshTmp("saveOffset") + val getSavedTmp = freshTmp(erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "saveOffset") def getSaved(off: BigInt): (Block => Block, Path) = if off == 0 then return (id, DynSelect(paths.runtimePath.selSN("resumeArr"), paths.runtimePath.selSN("resumeIdx"), true)) @@ -771,7 +771,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, override def applyResult(r: Result)(k: Result => Block) = r match case r @ EffectfulResult() => // Fallback case, this may lead to unnecessary assignments if it is assign-like - val l = freshTmp() + val l = freshTmp(erasedType = N) Scoped(Set(l), effectCheck(l, r, k(l.asSimpleRef))) case _ => super.applyResult(r)(k) topLevelPostTransform.applyBlock(b) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index 1a582cf06a..66a3a42f8f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -395,7 +395,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): else val newSym = closureMap.get(d) match case None => - val newSym = TempSymbol(N, d.nme + "$here") + val newSym = TempSymbol(N, erasedType = N, d.nme + "$here") extraLocals.add(newSym) syms.addOne(d -> newSym) // add to `syms`: this closure will be initialized in `applyBlock` closureMap.addOne(d -> newSym) // add to `closureMap`: `newSym` refers to the closure and can be used later @@ -415,7 +415,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): case cls: LiftedClass if !cls.isTrivial => val newSym = closureMap.get(d) match case None => - val newSym = TempSymbol(N, d.nme + "$here") + val newSym = TempSymbol(N, erasedType = N, d.nme + "$here") extraLocals.add(newSym) syms.addOne(d -> newSym) closureMap.addOne(d -> newSym) @@ -568,9 +568,12 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val nme = sym.nme + "$" + id val ident = new Tree.Ident(nme) - val varSym = VarSymbol(ident) + val capturedType = sym match + case s: HasErasedType => s.erasedType + case _ => N + val varSym = VarSymbol(ident, erasedType = capturedType) val fldSym = BlockMemberSymbol(nme, Nil) - val tSym = TermSymbol(syntax.MutVal, S(clsSym), ident) + val tSym = TermSymbol(syntax.MutVal, S(clsSym), ident, erasedType = capturedType) val p = Param(FldFlags.empty.copy(isVal = true), varSym, N, Modulefulness.none) varSym.decl = S(p) // * Currently this is only accessed to create the class' toString method @@ -679,9 +682,14 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): protected final def addExtraSyms(b: Block, captureSym: LocalVarSymbol, objSyms: Iterable[ScopedSymbol]): Block = if hasCapture then + val inst = instantiateCapture + // * The capture symbol holds an instance of the capture class, so it takes that type. + captureSym match + case s: HasOnceMutableErasedType => s.erasedType = inst.erasedType + case _ => Scoped( objSyms.toSet + captureSym, - Assign(captureSym, instantiateCapture, b) + Assign(captureSym, inst, b) ) else Scoped(objSyms.toSet, b) @@ -853,11 +861,11 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): * A rewritten scope with a generic VarSymbol capture symbol. */ sealed trait GenericRewrittenScope[T] extends RewrittenScope[T]: - lazy val captureSym = VarSymbol(Tree.Ident(obj.nme + "$cap")) + lazy val captureSym = VarSymbol(Tree.Ident(obj.nme + "$cap"), erasedType = N) override lazy val capturePath = captureSym.asSimpleRef protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid) protected val liftedObjsSyms: Map[InnerSymbol, VarSymbol] = liftedObjsOrdered.map: s => - s -> VarSymbol(Tree.Ident(s.nme + "$")) + s -> VarSymbol(Tree.Ident(s.nme + "$"), erasedType = N) .toMap override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = liftedObjsSyms.map: case k -> v => k -> v.asLocalPath @@ -868,11 +876,11 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): * A rewritten scope with a TermSymbol capture symbol. */ sealed trait ClsLikeRewrittenScope[T](sym: InnerSymbol) extends RewrittenScope[T]: - lazy val captureSym = TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath = Select(Value.This(sym), captureSym.id)(S(captureSym))(false) + lazy val captureSym = TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(obj.nme + "$cap"), erasedType = N) + override lazy val capturePath = Select(sym.asThis, captureSym.id)(S(captureSym))(false) protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid) protected val liftedObjsSyms: Map[InnerSymbol, TermSymbol] = liftedObjsOrdered.map: s => - s -> TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(s.nme + "$")) + s -> TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(s.nme + "$"), erasedType = N) .toMap override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = liftedObjsSyms.map: case k -> v => k -> LocalPath.privateSelfField(v) @@ -887,11 +895,15 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val (liftedMtds, extras) = mtds.map(liftNestedScopes).unzip(using l => (l.liftedDefn, l.extraDefns)) LifterResult(liftedMtds, extras.flatten) protected final def initCaptureField(b: Block): Block = - if hasCapture then AssignField(sym.asThis, captureSym.id, instantiateCapture, b)(S(captureSym)) + if hasCapture then + val inst = instantiateCapture + // * The capture field holds an instance of the capture class, so it takes that type. + captureSym.erasedType = inst.erasedType + AssignField(sym.asThis, captureSym.id, inst, b)(S(captureSym)) else b // some helpers - private def dupParam(p: Param): Param = p.copy(sym = VarSymbol(Tree.Ident(p.sym.nme))) + private def dupParam(p: Param): Param = p.copy(sym = VarSymbol(Tree.Ident(p.sym.nme), erasedType = N)) private def dupParams(plist: List[Param]): List[Param] = plist.map(dupParam) private def dupParamList(plist: ParamList): ParamList = plist.copy(params = dupParams(plist.params), restParam = plist.restParam.map(dupParam)) @@ -949,8 +961,8 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): extends RewrittenScope[ClsLikeDefn](obj) with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): - private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = Select(Value.This(obj.cls.isym), captureSym.id)(S(captureSym))(false) + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap"), erasedType = N) + override lazy val capturePath: Path = Select(obj.cls.isym.asThis, captureSym.id)(S(captureSym))(false) override def rewriteImpl: LifterResult[ClsLikeDefn] = val liftedSuper = obj.cls.parentPath.flatMap: @@ -983,8 +995,8 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): extends RewrittenScope[ClsLikeBody](obj) with ClsLikeRewrittenScope[ClsLikeBody](obj.clsBody.isym): - private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.clsBody.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = Select(Value.This(obj.clsBody.isym), captureSym.id)(S(captureSym))(false) + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.clsBody.isym), Tree.Ident(obj.nme + "$cap"), erasedType = N) + override lazy val capturePath: Path = Select(obj.clsBody.isym.asThis, captureSym.id)(S(captureSym))(false) override def rewriteImpl: LifterResult[ClsLikeBody] = val rewriterCtor = new BlockRewriter(N) @@ -1005,15 +1017,18 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): class LiftedFunc(override val obj: ScopedObject.Func)(using ctx: LifterCtxNew) extends LiftedScope[FunDefn](obj) with GenericRewrittenScope[FunDefn]: private val passedSymsMap_ : Map[ValueSymbol, VarSymbol] = passedSymsOrdered.map: s => - s -> VarSymbol(Tree.Ident(s.nme)) + val erasedType = s match + case h: HasErasedType => h.erasedType + case _ => N + s -> VarSymbol(Tree.Ident(s.nme), erasedType) .toMap private val capSymsMap_ : Map[ScopedInfo, VarSymbol] = capturesOrdered.map: i => val nme = data.getNode(i).obj.nme - i -> VarSymbol(Tree.Ident(nme + "$cap")) + i -> VarSymbol(Tree.Ident(nme + "$cap"), erasedType = N) .toMap private val defnSymsMap_ : Map[DefinitionSymbol[?], VarSymbol] = reqDefnsOrdered.sortBy(_.uid).map: i => val nme = data.getNode(i).obj.nme - i -> VarSymbol(Tree.Ident(nme + "$")) + i -> VarSymbol(Tree.Ident(nme + "$"), erasedType = N) .toMap override protected val passedSymsMap = passedSymsMap_.view.mapValues(_.asLocalPath).toMap @@ -1034,7 +1049,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val (mainSym, mainDsym) = (fun.sym, fun.dSym) val auxSym = BlockMemberSymbol(fun.sym.nme + "$", Nil, fun.sym.nameIsMeaningful) - val auxDsym = TermSymbol.fromFunBms(auxSym, fun.owner) + val auxDsym = TermSymbol.fromFunBms(auxSym, fun.owner, erasedType = N) // Definition with the auxiliary parameters merged into the first parameter list. private def mkFlattenedDefn: LifterResult[FunDefn] = @@ -1107,29 +1122,29 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): extends LiftedScope[ClsLikeDefn](obj) with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): - private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) - override lazy val capturePath: Path = Select(Value.This(obj.cls.isym), captureSym.id)(S(captureSym))(false) + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap"), erasedType = N) + override lazy val capturePath: Path = Select(obj.cls.isym.asThis, captureSym.id)(S(captureSym))(false) private val passedSymsMap_ : Map[ValueSymbol, (vs: VarSymbol, ts: TermSymbol)] = passedSymsOrdered.map: s => s -> ( - VarSymbol(Tree.Ident(s.nme)), - TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(s.nme)) + VarSymbol(Tree.Ident(s.nme), erasedType = N), + TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(s.nme), erasedType = N) ) .toMap private val capSymsMap_ : Map[ScopedInfo, (vs: VarSymbol, ts: TermSymbol)] = capturesOrdered.map: i => val nme = data.getNode(i).obj.nme + "$cap" i -> ( - VarSymbol(Tree.Ident(nme)), - TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(nme)) + VarSymbol(Tree.Ident(nme), erasedType = N), + TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(nme), erasedType = N) ) .toMap private val defnSymsMap_ : Map[DefinitionSymbol[?], (vs: VarSymbol, ts: TermSymbol)] = reqDefnsOrdered.map: i => i -> ( - VarSymbol(Tree.Ident(i.nme + "$")), - TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(i.nme + "$")) + VarSymbol(Tree.Ident(i.nme + "$"), erasedType = N), + TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(i.nme + "$"), erasedType = N) ) .toMap @@ -1160,12 +1175,12 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config): val cls = obj.cls val flattenedSym = BlockMemberSymbol(obj.cls.sym.nme + "$", Nil, true) - val flattenedDSym = TermSymbol.fromFunBms(flattenedSym, N) + val flattenedDSym = TermSymbol.fromFunBms(flattenedSym, N, erasedType = N) // Contains *all* parameters, and applies them all at once in a single `Instantiate` def mkFlattenedDefn: FunDefn = // Symbols for the aux parameter list - val auxSyms = auxParams.map(p => VarSymbol(Tree.Ident(p.sym.nme))) + val auxSyms = auxParams.map(p => VarSymbol(Tree.Ident(p.sym.nme), erasedType = N)) val auxParamListLocal = PlainParamList(auxSyms.map(Param.simple(_))) val dupedClsAuxParams = cls.auxParams.map(dupParamList(_)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index c495bd97de..0d7a9967a8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -43,8 +43,8 @@ class LoweringCtx( val map = initMap def collectScopedSym(s: ScopedSymbol) = definedSymsDuringLowering.add(s) def collectScopedSyms(s: ScopedSymbol*) = definedSymsDuringLowering.addAll(s) - def registerTempSymbol(trm: Option[Term], dbgNme: Str = "tmp")(using State) = - val tmp = new TempSymbol(trm, dbgNme) + def registerTempSymbol(trm: Option[Term], erasedType: Opt[ErasedType], dbgNme: Str = "tmp")(using State) = + val tmp = new TempSymbol(trm, erasedType, dbgNme) definedSymsDuringLowering.add(tmp) tmp def getCollectedSym: collection.Set[ScopedSymbol] = definedSymsDuringLowering @@ -256,6 +256,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): blockImpl(stats, res)))(using LoweringCtx.nestFunc) case syntax.Fun => val (paramLists, bodyBlock) = setupFunctionOrByNameDef(td.params, bod, S(td.sym.nme)) + populateFunDefnType(td.tsym, paramLists, td.sign, bodyBlock) val cfgOverride = td.extraAnnotations.collectFirst: case Annot.Config(modify) => modify(config) Define(FunDefn(td.owner, td.sym, td.tsym, paramLists, bodyBlock)(cfgOverride, td.annotations), @@ -296,6 +297,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ) case _ => _defn reportAnnotations(defn, defn.extraAnnotations) + (defn.paramsOpt.iterator ++ defn.auxParams.iterator).foreach: pl => + pl.params.foreach(populateClassParamErasedType) + pl.restParam.foreach(populateRestParamErasedType) val bufferableAnnots = defn.annotations.flatMap: case Annot.Trm(trm: SynthSel) => if trm.sym.contains(ctx.builtins.annotations.buffered) then @@ -440,8 +444,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): acc.reverse match case Nil => lowerRemainingCalls(fr, args, remainingArgss, annotations, loc)(k) case acc: NELs[Ls[Arg]] => - val tmp = loweringCtx.registerTempSymbol(N, "baseCall") val call = Call(fr, acc)(CallMetadata(isMlsFun, mayRaiseEffects, Nil)).withLoc(loc) + val tmp = loweringCtx.registerTempSymbol(N, erasedType = call.erasedType, "baseCall") Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, annotations, loc)(k)) case (_ :: _, Nil) => k(Call(fr, acc.reverse.ne_!)(CallMetadata(isMlsFun, mayRaiseEffects, annotations)).withLoc(loc)) @@ -460,7 +464,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): remainingArgss match case Nil => k(call) case args :: remainingArgss => - val tmp = loweringCtx.registerTempSymbol(N, "callPrefix") + val tmp = loweringCtx.registerTempSymbol(N, erasedType = call.erasedType, "callPrefix") Assign(tmp, call, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, annotations, loc)(k)) @@ -484,8 +488,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case (Nil, Nil) => k(buildInstantiate(acc.reverse)) case (Nil, args :: remainingArgss) => - val tmp = loweringCtx.registerTempSymbol(N, "baseInst") - Assign(tmp, buildInstantiate(acc.reverse), + val inst = buildInstantiate(acc.reverse) + val tmp = loweringCtx.registerTempSymbol(N, erasedType = inst.erasedType, "baseInst") + Assign(tmp, inst, lowerRemainingCalls(tmp.asSimpleRef, args, remainingArgss, annotations, N)(k)) case (remainingParamss, Nil) => // * Eta-expand missing argument lists by creating lambdas for each remaining param list. @@ -494,7 +499,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): remainingParamss match case Nil => buildInstantiate(accArgss) case ps :: rest => - val freshSyms = ps.params.map(p => new VarSymbol(new Tree.Ident(p.sym.nme))) + val freshSyms = ps.params.map(p => new VarSymbol(new Tree.Ident(p.sym.nme), erasedType = N)) softTODO(ps.restParam.isEmpty, "Eta expanding rest parameters in constructor definitions is not yet supported") val freshParams = (ps.params zip freshSyms).map((p, s) => Param(p.flags, s, N, p.modulefulness)) val freshParamList = ParamList(ps.flags, freshParams, N) @@ -518,8 +523,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Nil => k(buildInstantiate(as :: Nil)) case remainingArgss => - val tmp = loweringCtx.registerTempSymbol(N, "baseInst") - Assign(tmp, buildInstantiate(as :: Nil), + val inst = buildInstantiate(as :: Nil) + val tmp = loweringCtx.registerTempSymbol(N, erasedType = inst.erasedType, "baseInst") + Assign(tmp, inst, lowerRemainingCalls(tmp.asSimpleRef, remainingArgss.head, remainingArgss.tail, annotations, N)(k)) else zipArgs(ctorParamLists, args, Nil) @@ -561,9 +567,19 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): sym.owner match case S(owner) => AssignField(owner.asThis, sym.id, rhs, rest)(S(sym)) case N => nope - case sym: LocalVarSymbol => Assign(sym, rhs, rest) + case sym: LocalVarSymbol => + observeLocalErasedType(sym, rhs) + Assign(sym, rhs, rest) case sym => nope + /** Observes an assignment of `rhs` to `sym`, populating or updating its erased type where applicable. + * + * See [[`HasManyMutableErasedType.observeErasedTypeAssign`]]. + */ + private def observeLocalErasedType(sym: LocalVarSymbol, rhs: Result): Unit = sym match + case sym: HasManyMutableErasedType => sym.observeErasedTypeAssign(rhs.erasedType) + case _ => + private def defineSymbol(sym: Symbol, rhs: Result, rest: Block)(using LoweringCtx): Block = sym match case sym: TermSymbol => @@ -571,6 +587,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case S(owner) => AssignField(owner.asThis, sym.id, rhs, rest)(S(sym)) case N => lastWords(s"tried to define top-level symbol ${sym.showDbg} in a local scope") case sym: LocalVarSymbol => + observeLocalErasedType(sym, rhs) Assign(sym, rhs, rest) case sym => lastWords(s"tried to define non-variable symbol ${sym.showDbg}") @@ -609,8 +626,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if sym.binary then val t1 = new Tree.Ident("arg1") val t2 = new Tree.Ident("arg2") - val p1 = Param(FldFlags.empty, VarSymbol(t1), N, Modulefulness.none) - val p2 = Param(FldFlags.empty, VarSymbol(t2), N, Modulefulness.none) + val p1 = Param(FldFlags.empty, VarSymbol(t1, erasedType = N), N, Modulefulness.none) + val p2 = Param(FldFlags.empty, VarSymbol(t2, erasedType = N), N, Modulefulness.none) val ps = PlainParamList(p1 :: p2 :: Nil) val bod = st.App(ref, st.Tup(List(st.Ref(p1.sym)(t1, 666, N).resolve, st.Ref(p2.sym)(t2, 666, N).resolve)) (Tree.Tup(Nil // FIXME should not be required (using dummy value) @@ -625,7 +642,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): return k(Lambda(paramLists.head, bodyBlock)(Nil).withLocOf(ref)) if sym.unary then val t1 = new Tree.Ident("arg") - val p1 = Param(FldFlags.empty, VarSymbol(t1), N, Modulefulness.none) + val p1 = Param(FldFlags.empty, VarSymbol(t1, erasedType = N), N, Modulefulness.none) val ps = PlainParamList(p1 :: Nil) val bod = st.App(ref, st.Tup(List(st.Ref(p1.sym)(t1, 666, N).resolve)) (Tree.Tup(Nil // FIXME should not be required (using dummy value) @@ -718,8 +735,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if !hasNonLocalContinueDispatch then term_nonTail(body)(r => Assign(result, r, Break(label))) else - val bodyResult = loweringCtx.registerTempSymbol(N, "labelBodyResult") - val isContinue = loweringCtx.registerTempSymbol(N, "labelContinueDispatch") + val bodyResult = loweringCtx.registerTempSymbol(N, erasedType = N, "labelBodyResult") + val isContinue = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "labelContinueDispatch") term_nonTail(body): r => Assign( bodyResult, @@ -795,7 +812,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): S(k(Value.Lit(negLit))), Unreachable("tail operation in branches"), ) else - val ts = loweringCtx.registerTempSymbol(N) + val ts = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool))) Match( ar1, (Case.Lit(posLit) -> term_nonTail(arg2)(Assign(ts, _, End()))) :: Nil, @@ -921,7 +938,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): S(Handler(td.sym, resumeSym, paramLists, bodyBlock)) }.collect{ case Some(v) => v } loweringCtx.collectScopedSym(lhs) - val resSym = loweringCtx.registerTempSymbol(S(t)) + val resSym = loweringCtx.registerTempSymbol(S(t), erasedType = N) subTerm(rhs): par => subTerms(as): asr => HandleBlock(lhs, resSym, par, asr, cls, handlers, @@ -1039,7 +1056,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): Define(clsDef, term_nonTail(if mut then Mut(inner) else inner)(k)) case Try(sub, finallyDo) => - val l = loweringCtx.registerTempSymbol(S(sub)) + val l = loweringCtx.registerTempSymbol(S(sub), erasedType = N) TryBlock( subTerm_nonTail(sub)(p => Assign(l, p, End())), subTerm_nonTail(finallyDo)(_ => End()), @@ -1116,7 +1133,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): def quoteSplit(split: Split, splitTmps: Map[SplitSymbol, TempSymbol])(k: Result => Block)(using LoweringCtx): Block = split match case Split.Cons(Branch(scrutinee, pattern, continuation), tail) => quote(scrutinee): r1 => - val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N) + val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N, erasedType = N) blockBuilder.assign(l1, r1) .chain(b => quotePattern(pattern)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(continuation, splitTmps)(r3 => Assign(l3, r3, b))) @@ -1125,26 +1142,26 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): .rest(setupTerm("Cons", (l4 :: l5 :: Nil).map(s => s.asSimpleRef))(k)) case Split.Let(sym: LocalVarSymbol, term, tail) => setupSymbol(sym): r1 => loweringCtx.collectScopedSym(sym) - val l1, l2, l3 = loweringCtx.registerTempSymbol(N) + val l1, l2, l3 = loweringCtx.registerTempSymbol(N, erasedType = N) blockBuilder.assign(l1, r1) .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(term)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(tail, splitTmps)(r3 => Assign(l3, r3, b))) .rest(setupTerm("Let", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(k)) case Split.Else(default) => quote(default): r => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r, setupTerm("Else", l.asSimpleRef :: Nil)(k)) case Split.End => setupTerm("End", Nil)(k) case Split.LetSplit(sym, tail) => - val tmp = loweringCtx.registerTempSymbol(N, sym.nme + "_splitTmp") + val tmp = loweringCtx.registerTempSymbol(N, erasedType = N, sym.nme + "_splitTmp") val splitTmps2 = splitTmps + (sym -> tmp) setupSymbol(tmp): r1 => - val l1, l2, l3 = loweringCtx.registerTempSymbol(N) + val l1, l2, l3 = loweringCtx.registerTempSymbol(N, erasedType = N) blockBuilder.assign(l1, r1) .chain(b => Assign(tmp, l1.asSimpleRef, b)) .chain(b => quoteSplit(sym.body, splitTmps2)(r2 => Assign(l2, r2, b))) .chain(b => quoteSplit(tail, splitTmps2)(r3 => Assign(l3, r3, b))) - .rest(setupTerm("LetSplit", (l1 :: l2 :: l3 :: Nil).map(s => s.asSimpleRef))(k)) + .rest(setupTerm("LetSplit", (l1 :: l2 :: l3 :: Nil).map(_.asSimpleRef))(k)) case Split.UseSplit(sym) => setupTerm("UseSplit", splitTmps(sym).asSimpleRef :: Nil)(k) lazy val setupFilename: Path = @@ -1155,7 +1172,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Lit(lit) => setupTerm("Lit", Value.Lit(lit) :: Nil)(k) case Ref(sym) if Elaborator.binaryOps.contains(sym.nme) => // builtin symbols - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) setupTerm("Builtin", Value.Lit(Tree.StrLit(sym.nme)) :: Nil)(k) case Resolved(Ref(sym), disamb) => sym match @@ -1166,7 +1183,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Ref(sym) => lastWords(s"Unexpected symbol kind ${sym.getClass.getSimpleName}: $sym") case SynthSel(Ref(sym: ModuleOrObjectSymbol), name) => // Module/object cross-stage references setupSymbol(sym): r1 => - val l1, l2 = loweringCtx.registerTempSymbol(N) + val l1, l2 = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l1, r1, setupTerm("CSRef", l1.asSimpleRef :: setupFilename :: Value.Lit(syntax.Tree.UnitLit(false)) :: Nil)(r2 => Assign(l2, r2, setupTerm("Sel", l2.asSimpleRef :: Value.Lit(syntax.Tree.StrLit(name.name)) :: Nil)(k)) )) @@ -1179,7 +1196,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): ) else (t.toLoc, sym.toLoc) match case (S(Loc(_, _, Origin(base, _, _))), S(Loc(_, _, Origin(filename, _, _)))) => setupSymbol(sym): r1 => - val l1, l2 = loweringCtx.registerTempSymbol(N) + val l1, l2 = loweringCtx.registerTempSymbol(N, erasedType = N) val basePath = base.up val targetPath = filename val relPath = targetPath.relativeTo(basePath).map(_.toString).getOrElse(targetPath.toString) @@ -1195,8 +1212,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case Lam(params, body) => def rec(ps: Ls[VarSymbol], ds: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = ps match case Nil => quote(body): r => - val l = loweringCtx.registerTempSymbol(N) - val arr = loweringCtx.registerTempSymbol(N, "arr") + val l = loweringCtx.registerTempSymbol(N, erasedType = N) + val arr = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "arr") Assign( arr, Tuple(mut = false, ds.reverse.map(_.asArg)), @@ -1204,24 +1221,24 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): case sym :: rest => loweringCtx.collectScopedSym(sym) setupSymbol(sym): r => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r, setupTerm("Ref", l.asSimpleRef :: Nil): r1 => Assign(sym, r1, rec(rest, l.asSimpleRef :: ds)(k))) rec(params.params.map(_.sym), Nil)(k) // TODO: restParam? case App(lhs, Tup(rhs)) => quote(lhs): r1 => def rec(es: Ls[Elem], xs: Ls[Path])(k: Result => Block): Block = es match case Nil => - val arrSym = loweringCtx.registerTempSymbol(N, "arr") + val arrSym = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "arr") Assign( arrSym, Tuple(mut = false, xs.reverse.map(_.asArg)), setupTerm("Tup", arrSym.asSimpleRef :: Nil): r2 => - val l1 = loweringCtx.registerTempSymbol(N) - val l2 = loweringCtx.registerTempSymbol(N) + val l1 = loweringCtx.registerTempSymbol(N, erasedType = N) + val l2 = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l1, r1, Assign(l2, r2, setupTerm("App", l1.asSimpleRef :: l2.asSimpleRef :: Nil)(k))) ) case Fld(_, t, _) :: rest => quote(t): r2 => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r2, rec(rest, l.asSimpleRef :: xs)(k)) case Spd(eager, term) :: rest => fail: @@ -1235,8 +1252,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): require(sym2 is sym) loweringCtx.collectScopedSym(sym) setupSymbol(sym){r1 => - val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N) - val arrSym = loweringCtx.registerTempSymbol(N, "arr") + val l1, l2, l3, l4, l5 = loweringCtx.registerTempSymbol(N, erasedType = N) + val arrSym = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "arr") blockBuilder.assign(l1, r1) .chain(b => setupTerm("Ref", l1.asSimpleRef :: Nil)(r => Assign(sym, r, b))) .chain(b => quote(rhs)(r2 => Assign(l2, r2, b))) @@ -1247,7 +1264,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): .rest(setupTerm("Blk", arrSym.asSimpleRef :: l3.asSimpleRef :: Nil)(k)) } case IfLike(_, IfLikeForm.ReturningIf, split) => quoteSplit(split.getExpandedSplit, Map.empty): r => - val l = loweringCtx.registerTempSymbol(N) + val l = loweringCtx.registerTempSymbol(N, erasedType = N) Assign(l, r, setupTerm("IfLike", setupQuotedKeyword("If") :: l.asSimpleRef :: Nil)(k)) case Unquoted(body) => term(body)(k) case _ => fail: @@ -1263,6 +1280,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): .flatMap: td => td.body.map: bod => val (paramLists, bodyBlock) = setupFunctionDef(td.params, bod, S(td.sym.nme)) + populateFunDefnType(td.tsym, paramLists, td.sign, bodyBlock) reportAnnotations(td, td.extraAnnotations) val cfgOverride = td.extraAnnotations.collectFirst: case Annot.Config(modify) => modify(config) @@ -1317,7 +1335,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): if fsr.isEmpty then Begin(b, k(asr.reverse)) else - val rcdSym = loweringCtx.registerTempSymbol(N, "rcd") + val rcdSym = loweringCtx.registerTempSymbol(N, erasedType = S(ErasedType.ObjectRef), "rcd") Begin( b, Assign( @@ -1351,8 +1369,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(configOverride = N, annotations = lam.annot) Define(lamDef, k(lamDef.asPath)) case r => - val l = loweringCtx.registerTempSymbol(N) - Assign(l, r, k(l |> Value.SimpleRef.apply)) + val l = loweringCtx.registerTempSymbol(N, erasedType = r.erasedType) + Assign(l, r, k(l.asSimpleRef)) def program(main: st.Blk, symbolsToPreserve: Set[BoundSymbol]): Program = @@ -1392,8 +1410,55 @@ class Lowering()(using Config, TL, Raise, State, Ctx, SymbolPrinter): val scopedSyms = loweringCtx.getCollectedSym.filterNot(syms) Scoped(scopedSyms, body) + /** Erases a type-annotated term to an [[`ErasedType`]]. */ + private def eraseSign(sign: Term): Opt[ErasedType] = + sign.symbol.flatMap(_.asClsOrMod).map(sym => ErasedType.fromClsLikeSymbol(sym, rsc = false)) + + /** Populates the [[`ErasedType`]] of a class parameter. */ + private def populateClassParamErasedType(p: Param): Unit = + p.sign.flatMap(eraseSign).foreach: et => + p.sym.populateErasedType(et) + p.fldSym.foreach: + case fld: TermSymbol => fld.populateErasedType(et) + case bms: BlockMemberSymbol => bms.tsym.foreach(_.populateErasedType(et)) + case _ => + + /** Populates the [[`ErasedType`]] of the `rest` parameter. */ + private def populateRestParamErasedType(p: Param): Unit = + p.sym.populateErasedType(ErasedType.Primitive(PrimitiveType.Array)) + + /** Infers the [[`ErasedType`]] of a function's return type by joining the erased types of all return values + * via [[`ErasedType.join`]]. + */ + private def inferReturn(body: Block): Opt[ErasedType] = + var rets: Ls[Opt[ErasedType]] = Nil + new BlockTraverserShallow: + override def applyBlock(b: Block): Unit = b match + case Return(res) => rets ::= res.erasedType + case _ => super.applyBlock(b) + .applyBlock(body) + rets match + case head :: tail => tail.foldLeft(head)(ErasedType.join) + case Nil => N + + /** Populates a function definition symbol's erased type with a [[`ErasedType.FuncRef`]] derived from its + * (already-populated) parameter symbols and return type. + * + * The return type comes from the explicit annotation when present, otherwise it is inferred from the body. + */ + // TODO(Derppening): Parameters of curried functions are currently flattened - Should we preserve the curried shape? + private def populateFunDefnType(tsym: TermSymbol, paramLists: Ls[ParamList], sign: Opt[Term], body: Block): Unit = + if tsym.erasedType.isEmpty then + val params = paramLists.flatMap(_.params).map(_.sym.erasedType) + val ret = sign.flatMap(eraseSign) orElse inferReturn(body) + tsym.erasedType = S(ErasedType.FuncRef(params, ret)) + def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) (using LoweringCtx): (List[ParamList], Block) = + paramLists.foreach: pl => + pl.params.foreach: p => + p.sign.flatMap(eraseSign).foreach(p.sym.populateErasedType) + pl.restParam.foreach(populateRestParamErasedType) val scopedBody = inScopedBlock(returnedTerm(bodyTerm)) (paramLists, scopedBody) @@ -1489,12 +1554,12 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) go(paramLists.reverse, bod) def setupFunctionBody(params: ParamList, bod: Term, name: Option[Str])(using LoweringCtx): Block = inScopedBlock: - val enterMsgSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogEnterMsg") - val prevIndentLvlSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogPrevIndent") - val resSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogRes") - val retMsgSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogRetMsg") - val psInspectedSyms = params.params.map(p => loweringCtx.registerTempSymbol(N, dbgNme = s"traceLogParam_${p.sym.nme}") -> p.sym) - val resInspectedSym = loweringCtx.registerTempSymbol(N, dbgNme = "traceLogResInspected") + val enterMsgSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogEnterMsg") + val prevIndentLvlSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogPrevIndent") + val resSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogRes") + val retMsgSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogRetMsg") + val psInspectedSyms = params.params.map(p => loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = s"traceLogParam_${p.sym.nme}") -> p.sym) + val resInspectedSym = loweringCtx.registerTempSymbol(N, erasedType = N, dbgNme = "traceLogResInspected") val psSymArgs = psInspectedSyms.zipWithIndex.foldRight[Ls[Arg]](Arg(N, Value.Lit(Tree.StrLit(")"))) :: Nil): @@ -1502,7 +1567,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) then Arg(N, s.asSimpleRef) :: acc else Arg(N, s.asSimpleRef) :: Arg(N, Value.Lit(Tree.StrLit(", "))) :: acc - val tmp1, tmp2, tmp3 = loweringCtx.registerTempSymbol(N) + val tmp1, tmp2, tmp3 = loweringCtx.registerTempSymbol(N, erasedType = N) assignStmts(psInspectedSyms.map: (pInspectedSym, pSym) => pInspectedSym -> pureCall(inspectFn, Arg(N, pSym.asSimpleRef) :: Nil) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index 8efbb7e062..758463eb41 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -30,6 +30,25 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case S(str) => str case N => summon[SymbolPrinter].printSymbol(l) + def print(et: ErasedType)(using Scope): Document = et match + case ErasedType.AnyRef(rsc, csym: ClassLikeSymbol) => doc"${if rsc then "rsc " else ""}${print(csym)}" + case ErasedType.AnyRef(rsc, NoSymbol) => doc"${if rsc then "rsc " else ""}Object" + case ErasedType.FuncRef(params, ret) => + doc"(${params.map(_.fold(doc"?")(print)).mkDocument(sep = doc", ")}) => ${ret.fold(doc"?")(print)}" + case ErasedType.Primitive(prim) => doc"${prim.toString}" + + /** Renders the type annotation for a symbol with an [[`ErasedType`]]. */ + def erasedTypeAnnot(x: HasErasedType)(using Scope): Document = + if !summon[ShowCfg].showErasedTypes then doc"" + else doc": ${x.erasedType.fold(doc"?")(print)}" + + /** Renders a function's return type, projected from the `FuncRef` carried by its definition symbol. */ + def returnTypeAnnot(dSym: TermSymbol)(using Scope): Document = + if !summon[ShowCfg].showErasedTypes then doc"" + else dSym.erasedType match + case S(ErasedType.FuncRef(_, ret)) => doc": ${ret.fold(doc"?")(print)}" + case _ => doc"" + def print(blk: Block)(using Scope): Document = blk match case Match(scrut, arms, dflt, rest) => def case_doc(c: Case) = c match @@ -69,7 +88,9 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case Scoped(syms, body) => scope.nest.givenIn: import hkmc2.given_Ordering_Uid // Not sure why needed... - val names = syms.toList.sortBy(_.uid).map(s => scope.allocateName(s)) + val names = syms.toList.sortBy(_.uid).map: + case sym: LocalVarSymbol => doc"${scope.allocateName(sym)}${erasedTypeAnnot(sym)}" + case bms: BlockMemberSymbol => doc"${scope.allocateName(bms)}" doc"let ${names.mkDocument(", ")}; # ${print(body)}" case End(msg) if msg.nonEmpty && config.commentGeneratedCode => doc"end /* ${msg} */" case End(_) => doc"end" @@ -99,8 +120,8 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): ctor: Block, ctorSym: Opt[TermSymbol], )(using Scope): Document = - val privFields = privateFields.map(x => doc"private val ${print(x)};").mkDocument(sep = doc" # ") - val pubFields = publicFields.map(x => doc"val ${print(x._1)};").mkDocument(sep = doc" # ") + val privFields = privateFields.map(x => doc"private val ${print(x)}${erasedTypeAnnot(x)};").mkDocument(sep = doc" # ") + val pubFields = publicFields.map(x => doc"val ${print(x._1)}${erasedTypeAnnot(x._2)};").mkDocument(sep = doc" # ") val docPrivFlds = if privateFields.isEmpty then doc"" else doc" # ${privFields}" val docPubFlds = if publicFields.isEmpty then doc"" else doc" # ${pubFields}" val docPreCtor = preCtor match @@ -129,8 +150,8 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): paramss .map: pl => val allParams = - pl.params.map(x => scope.allocateName(x.sym)) ++ - pl.restParam.map(x => "..." + scope.allocateName(x.sym)) + pl.params.map(x => doc"${scope.allocateName(x.sym)}${erasedTypeAnnot(x.sym)}") ++ + pl.restParam.map(x => doc"...${scope.allocateName(x.sym)}${erasedTypeAnnot(x.sym)}") allParams.mkDocument("(", ", ", ")") .mkDocument("") @@ -140,9 +161,9 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): val docParams = printParamLists(paramss) val docBody = print(body) val docStaged = if fun.isStaged then doc"staged " else doc"" - doc"${docStaged}fun ${print(dSym)}${docParams} ${bracedbk(docBody)}" + doc"${docStaged}fun ${print(dSym)}${docParams}${returnTypeAnnot(dSym)} ${bracedbk(docBody)}" case ValDefn(tsym, sym, rhs) => - doc"val ${print(tsym)} = ${print(rhs)}" + doc"val ${print(tsym)}${erasedTypeAnnot(tsym)} = ${print(rhs)}" case cls @ ClsLikeDefn(own, isym, sym, ctorSym, k, paramsOpt, auxParams, parentSym, methods, privateFields, publicFields, preCtor, ctor, mod, bufferable) => scope.nest.givenIn: @@ -197,8 +218,8 @@ class Printer(using Raise, ShowCfg, State, SymbolPrinter, Config): case Lambda(params, body) => scope.nest.givenIn: val allParams = - params.params.map(x => scope.allocateName(x.sym)) ++ - params.restParam.map(x => "..." + scope.allocateName(x.sym)) + params.params.map(x => doc"${scope.allocateName(x.sym)}${erasedTypeAnnot(x.sym)}") ++ + params.restParam.map(x => doc"...${scope.allocateName(x.sym)}${erasedTypeAnnot(x.sym)}") val docParams = allParams.mkDocument("(", ", ", ")") doc"$docParams => ${bracedbk(print(body))}" case Tuple(mut, elems) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala index f1e8c8e3eb..bda1ec0dfc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ReflectionInstrumenter.scala @@ -54,7 +54,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n def assign(res: Result, symName: Str = "tmp")(k: Path => Block): Block = // TODO: skip assignment if res: Path? - val sym = new TempSymbol(N, symName) + val sym = new TempSymbol(N, erasedType = res.erasedType, symName) Scoped(Set(sym), Assign(sym, res, k(sym.asSimpleRef))) def tuple(elems: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = @@ -374,7 +374,7 @@ class ReflectionInstrumenter(using State, Raise, Ctx) extends BlockTransformer(n val sym = f.owner.get.asThis.selSN(genSymName) // turn into fundefn - val dSym = TermSymbol(f.dSym.k, f.dSym.owner, Tree.Ident(f.sym.nme + "_instr")) + val dSym = TermSymbol(f.dSym.k, f.dSym.owner, Tree.Ident(f.sym.nme + "_instr"), erasedType = N) val argSyms = f.params.flatMap(_.params).map(_.sym) val newBody = Scoped(Set(argSyms*), transformFunDefn(f)(using new HashMap)(Return(_))) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala index a2198984b1..d88a6c4f21 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/SymbolRefresher.scala @@ -16,10 +16,10 @@ class SymbolRefresherWalker(mapping: MutMap[Symbol, Symbol])(using State) extend mapping(k) = v private def refreshTempSymbol(s: TempSymbol) = - assertUpdate(s, new TempSymbol(s.trm, s.nme)) + assertUpdate(s, new TempSymbol(s.trm, s.erasedType, s.nme)) private def refreshVarSymbol(s: VarSymbol) = - val ns = new VarSymbol(s.id) + val ns = new VarSymbol(s.id, s.erasedType) ns.sourceAliases = s.sourceAliases assertUpdate(s, ns) @@ -33,7 +33,7 @@ class SymbolRefresherWalker(mapping: MutMap[Symbol, Symbol])(using State) extend private def refreshTermSymbol(s: TermSymbol) = // Inner symbol (if present) must be traversed at this point. - val ns = new TermSymbol(s.k, s.owner.map(o => mapping.getOrElse(o, o).asInstanceOf[InnerSymbol]), s.id) + val ns = new TermSymbol(s.k, s.owner.map(o => mapping.getOrElse(o, o).asInstanceOf[InnerSymbol]), s.id, s.erasedType) ns.sourceAliases = s.sourceAliases assertUpdate(s, ns) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 433c886b02..1d17136d3c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -272,9 +272,9 @@ class TailRecOpt(using State, TL, Raise): else // Duplicate the params for the internal loop defn (see the doc at the // end of this function), but preserve the names. - syms.map(v => VarSymbol(Tree.Ident(v.id.name))) + syms.map(v => VarSymbol(Tree.Ident(v.id.name), erasedType = v.erasedType)) else - for i <- 0 until maxParamLen yield VarSymbol(Tree.Ident("param" + i)) + for i <- 0 until maxParamLen yield VarSymbol(Tree.Ident("param" + i), erasedType = N) .toList val paramSymsArr = ArrayBuffer.from(paramSyms) // Function -> param -> param symbol in the rewritten function @@ -294,9 +294,9 @@ class TailRecOpt(using State, TL, Raise): else BlockMemberSymbol(funs.iterator.map(_.sym.nme).mkString("_"), Nil, true) val dSym = if funsLen === 1 then funs.head.dSym - else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme)) + else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme), erasedType = N) val loopSym = LabelSymbol(N, "loopLabel") - val curIdSym = VarSymbol(Tree.Ident("id")) + val curIdSym = VarSymbol(Tree.Ident("id"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int))) class FunRewriter(f: FunDefn) extends BlockTransformerShallow(SymbolSubst.Id): val params = getParamSyms(f) @@ -331,7 +331,7 @@ class TailRecOpt(using State, TL, Raise): .toSet val copiedParamSyms = copiedParams.map: x => - x -> VarSymbol(x.id) + x -> VarSymbol(x.id, erasedType = N) .toMap val subst = new SymbolSubst: @@ -366,7 +366,7 @@ class TailRecOpt(using State, TL, Raise): // We should thus assign the params to temporary symbols // if they are needed for a subsequent assignment. var assignedSyms: Map[VarSymbol, Lazy[TempSymbol]] = paramSyms.map: sym => - sym -> Lazy(TempSymbol(N, sym.nme + "_tmp")) // Use `Lazy` to avoid generating useless symbols + sym -> Lazy(TempSymbol(N, erasedType = sym.erasedType, sym.nme + "_tmp")) // Use `Lazy` to avoid generating useless symbols .toMap var requiredTmps: Set[(VarSymbol, TempSymbol)] = Set.empty @@ -430,7 +430,7 @@ class TailRecOpt(using State, TL, Raise): val paramList = ogParamList.params val restParam = ogParamList.restParam - val tupleSym = TempSymbol(N, "argList") + val tupleSym = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "argList") // We can safely remove all of the symbols from this parameter list from `assignedSyms` at this stage, // because the RHS of every parameter will be computed when spreading them in the tuple, which happens @@ -446,7 +446,7 @@ class TailRecOpt(using State, TL, Raise): // If the rest param exists, append a slice val (initialBlk: (Block => Block), pathList: List[Path]) = if restParam.isDefined then - val sliceResSym = TempSymbol(N, "sliceRes") + val sliceResSym = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "sliceRes") // runtime.Tuple.slice(tupleSym, paramList.length, 0) val sliceRes = Call( State.runtimeSymbol.asSimpleRef @@ -525,7 +525,7 @@ class TailRecOpt(using State, TL, Raise): val params = paramSyms.map(Param.simple(_)) val loopBms = BlockMemberSymbol(bms.nme + "$tailrec", Nil, true) - val loopDSym = TermSymbol(syntax.Fun, owner, Tree.Ident(loopBms.nme)) + val loopDSym = TermSymbol(syntax.Fun, owner, Tree.Ident(loopBms.nme), erasedType = N) val loopAnnots = if f.inline then Annot.Inline :: Annot.Private :: Nil else Annot.Private :: Nil diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala index 817a382123..6ba54ab0ed 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/WorkerWrapper.scala @@ -39,7 +39,7 @@ class WorkerWrapper body.size <= cfg.altSmallThreshold private def freshParam(param: Param, mapping: collection.mutable.Map[Symbol, Symbol]): Param = - val freshSym = new VarSymbol(param.sym.id) + val freshSym = new VarSymbol(param.sym.id, erasedType = param.sym.erasedType) mapping(param.sym) = freshSym Param(param.flags, freshSym, param.sign, param.modulefulness).withSignTypeOf(param) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala index 9f1f90ee8a..87e4638a56 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala @@ -114,7 +114,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val name = path.mkFunName + s"$$${f.nme}" f -> ( new BlockMemberSymbol(name, Nil, true), - new TermSymbol(Fun, N, Tree.Ident(name))) + new TermSymbol(Fun, N, Tree.Ident(name), erasedType = N)) .toMap) end mkNewPolyFnSyms @@ -150,7 +150,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val clsNme = selInfo.selectsFrom.ctorClsName fieldSym.getOrElseUpdate( selInfo.field, - new VarSymbol(Tree.Ident(s"${clsNme}_${selInfo.field.fieldName}"))) + new VarSymbol(Tree.Ident(s"${clsNme}_${selInfo.field.fieldName}"), erasedType = N)) ) // ctor dest branch function computations @@ -178,7 +178,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val scrutName = dest._1.getReferredSym.nme val branchFnNme = s"${dest.instId.mkFunName}$$$scrutName$branchName" new BlockMemberSymbol(branchFnNme, Nil, true) - -> new TermSymbol(Fun, N, Tree.Ident(branchFnNme)) + -> new TermSymbol(Fun, N, Tree.Ident(branchFnNme), erasedType = N) ) // compute the function parameters corresponding to ctor fields of branch funs branchFunParamFieldSyms.getOrElseUpdate( @@ -193,7 +193,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): completeArgs.map: selField => selsInfos.get(selField) match case Some(selId) => branchSelSyms(selId) - case None => VarSymbol(Tree.Ident(s"_${selField.fieldName}")) + case None => VarSymbol(Tree.Ident(s"_${selField.fieldName}"), erasedType = N) ) val (parents, _) = getParentLabelOrMatchesAndRestBefore(dest.exprId) @@ -207,7 +207,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): locally: val restFunName = dest.instId.mkFunName + s"$$${nme}_rest" new BlockMemberSymbol(restFunName, Nil, true) - -> new TermSymbol(Fun, N, Tree.Ident(restFunName)) + -> new TermSymbol(Fun, N, Tree.Ident(restFunName), erasedType = N) ) val (ps, restBeforeParent) = getParentLabelOrMatchesAndRestBefore(matchOrLabelId) restOriginalBodiesAndParentRest.getOrElseUpdate( @@ -390,7 +390,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val ctorInfo = solver.fusingCtorInfo(ctorDtorId) val clsNme = ctorInfo.ctor.ctorClsName ctorInfo.args.unzip._1.map: f => - new TempSymbol(N, s"${clsNme}_${f.fieldName}") + new TempSymbol(N, erasedType = N, s"${clsNme}_${f.fieldName}") end mkCtorFieldSyms solver.finalCtorDests.get(ctor.uid.concreteId) match @@ -422,7 +422,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): case ctor@CtorProducer(_, args, _) if solver.finalCtorDests.isDefinedAt(ctor.uid.concreteId) => assert(args.isEmpty) val callBranchFun = mkCall(branchFunSyms(ctorWhichBranch(ctor.uid.concreteId)), Nil) - val lambdaSym = new TempSymbol(N, "deforest$lam") + val lambdaSym = new TempSymbol(N, erasedType = N, "deforest$lam") Scoped( Set.single(lambdaSym), Assign( @@ -531,7 +531,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): ParamList( pl.flags, pl.params.map: p => - val newSym = new VarSymbol(Tree.Ident(p.sym.name)) + val newSym = new VarSymbol(Tree.Ident(p.sym.name), erasedType = N) refreshParamMap(p.sym) = newSym Param(p.flags, newSym, p.sign, p.modulefulness), pl.restParam) @@ -551,7 +551,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): val actualBody = Begin( new Rewriter(instId).applyBlock(ogBody), Return(mkCall(restFunSym, restFunArgs))) - val refreshedFvSymbols = dtorBranchFnFvs(branchId._1).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"))) + val refreshedFvSymbols = dtorBranchFnFvs(branchId._1).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"), erasedType = N)) val bodyWithCorrectSymbols = refreshExtractedBody(refreshedFvSymbols.toMap, actualBody) FunDefn(N, bms, tSym, branchFunParamFieldSyms(branchId).asParamList :: refreshedFvSymbols.unzip._2.asParamList :: Nil, @@ -575,7 +575,7 @@ class DeforestRewriter(val solver: DeforestFusionSolver)(using Raise): Return(mkCall(parentFunSym, parentFunFvs))) case None => Begin(transformedOgBody, Return(Value.Lit(Tree.UnitLit(true)))) - val refreshedFvSymbols = restFnFvs(restFunId).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"))) + val refreshedFvSymbols = restFnFvs(restFunId).map(s => s -> new VarSymbol(Tree.Ident(s"fv_${s.nme}"), erasedType = N)) val bodyWithCorrectSymbols = refreshExtractedBody(refreshedFvSymbols.toMap, actualBody) FunDefn(N, bms, tsym, refreshedFvSymbols.unzip._2.asParamList :: Nil, bodyWithCorrectSymbols)(N, annotations = PrivateModifier :: Nil) end newRestFuns diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index b26ecadade..b3a2df4f91 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -76,7 +76,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: case _ => false private def getPrivateAccessorSymbol(ts: semantics.TermSymbol): semantics.TempSymbol = - privateAccessorSymbols.getOrElseUpdate(ts, semantics.TempSymbol(N, s"${ts.name}$$accessorSymbol")) + privateAccessorSymbols.getOrElseUpdate(ts, semantics.TempSymbol(N, erasedType = N, s"${ts.name}$$accessorSymbol")) private def selectPrivateField(ts: semantics.TermSymbol, loc: Opt[Loc])(using Raise, Scope): Opt[Document] = ts.owner.collect: @@ -472,7 +472,7 @@ class JSBuilder(using Config, TL, State, Ctx) extends CodeBuilder: pubFlds.collect: case (_, sym) if sym.k is MutVal => sym -> TermSymbol( - syntax.LetBind, S(isym), Tree.Ident(sym.nme)) + syntax.LetBind, S(isym), Tree.Ident(sym.nme), erasedType = sym.erasedType) val allPrivFlds = privFlds ++ mutPubFields.map(_._2) val privDecls = allPrivFlds.map: fld => val nme = isym.privatesScope.allocateOrGetName(fld) @@ -1050,7 +1050,7 @@ trait JSBuilderArgNumSanityChecks(using TL, Config, Elaborator.State) override def checkSelections: Bool = instrument override def freezeDefinitions: Bool = instrument - val functionParamVarargSymbol = semantics.TempSymbol(N, "args") + val functionParamVarargSymbol = semantics.TempSymbol(N, erasedType = N, "args") override def setupFunction(name: Option[Str], params: ParamList, body: Block, isLambda: Bool)(using Raise, Scope): (Document, Document) = // * We used to instrument `fun f(x, y) = x + y` into something like diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala index 3c11177296..4f44107a4a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala @@ -191,6 +191,10 @@ final class SessionExportCtx( * The expression of the function body. * @param exportName * Optional export name. + * @param typedParams + * Whether parameter slots should be declared with their `erasedType`-derived Wasm type (via + * `ValueSymbol.paramRefType`) rather than uniformly `anyref`. Only safe for top-level free functions, which are not + * subject to the shared vtable calling convention; `false` for everything else (methods, ctors, `init`). */ class FuncInfo( val sym: BlockMemberSymbol | TempSymbol, @@ -201,14 +205,15 @@ class FuncInfo( val body: Expr, val exportName: Opt[Str], val wrapId: Opt[Str] -> Opt[Str] = N -> N, -)(using Ctx, Raise) extends ToWat: + val typedParams: Bool = false, +)(using Ctx, Raise, State) extends ToWat: /** Symbolic identifier for the function. */ val id = SymIdx(summon[Ctx].funcScp.allocateOrGetNameWrapped(sym, wrapId)) /** Returns the type of this function as a [[SignatureType]]. */ def getSignatureType: SignatureType = SignatureType( - params = params.map((_, paramIdx) => WasmParam(paramIdx, RefType.anyref)), + params = params.map((sym, paramIdx) => WasmParam(paramIdx, if typedParams then sym.paramRefType else RefType.anyref)), results = resultTypes, ) @@ -220,7 +225,7 @@ class FuncInfo( getSignatureType.toWat.surroundUnlessEmpty(doc" ") } #{ ${ locals.map: p => - doc"(local ${p._2.toWat} ${RefType.anyref.toWat})" + doc"(local ${p._2.toWat} ${p._1.localRefType.toWat})" .mkDocument(doc" # ").surroundUnlessEmpty(doc" # ") } # ${body.toWat} #} )""" end FuncInfo @@ -360,8 +365,11 @@ object FunctionCtx: * The parameters of this function. * @param thisSym * The implicit `this` parameter symbol if this function is generated from a non-static method, or `N` otherwise. + * @param typedParams + * Whether parameter slots should be declared/loaded with their `erasedType`-derived Wasm type rather than uniformly + * `anyref`. */ -class FunctionCtx(_params: Ls[ParamList], thisSym: Opt[InnerSymbol])(using Raise, State): +class FunctionCtx(_params: Ls[ParamList], thisSym: Opt[InnerSymbol], typedParams: Bool = false)(using Raise, State): /** [[Scope]] for generating WAT identifiers of locals. */ private[text] val localScp = Scope.empty(Scope.Cfg.default) @@ -404,6 +412,21 @@ class FunctionCtx(_params: Ls[ParamList], thisSym: Opt[InnerSymbol])(using Raise */ def locals: Seq[ValueSymbol -> SymIdx] = _locals.map(l => l -> SymIdx(localScp.lookup_!(l, N))).toSeq + /** The declared Wasm reference type of the param/local slot for `sym`. + * + * Parameters are `anyref` by default: their declared type is fixed by the shared call/vtable + * calling convention. + * + * When `typedParams` is set (e.g. when compiling free functions), parameter slots instead derive their type from + * [[ValueSymbol.paramRefType]]. + * + * Local slots always derive their type from the symbol's erased type via [[localRefType]]. + */ + def slotRefType(sym: ValueSymbol)(using Ctx): RefType = + if params.exists(_._1 == sym) then + if typedParams then sym.paramRefType else RefType.anyref + else sym.localRefType + /** Pushes a label target for the dynamic extent of `body` and pops it afterwards. * * The `body` function is given a [[LabelTarget]] containing the `break` and `continue` labels corresponding to @@ -444,8 +467,9 @@ end FunctionCtx def genFuncBody[T]( params: Ls[ParamList], thisSym: Opt[InnerSymbol], + typedParams: Bool = false, )(mkBody: FunctionCtx ?=> T)(using Raise, State): T -> FunctionCtx = - val funcCtx = FunctionCtx(params, thisSym) + val funcCtx = FunctionCtx(params, thisSym, typedParams) val result = mkBody(using funcCtx) result -> funcCtx @@ -696,7 +720,7 @@ class Ctx(using State) extends ToWat: val id = SymIdx(name) memories = memories + (id -> - Import(module, name, ExternType.Mem(MemType(Limits(minPages)), sym = TempSymbol(N, name)))) + Import(module, name, ExternType.Mem(MemType(Limits(minPages)), sym = TempSymbol(N, erasedType = N, name)))) cachedMemoryImport(key) = SymIdx(name) end match end ensureMemoryImport diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 0904a24735..4532dc5ac5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -26,6 +26,42 @@ extension (instr: FoldedInstr) private def mnemonicPrefix: Opt[Str] = instr.mnemonic.split('.').optionUnless(_.size == 1).map(_.head) +extension (et: ErasedType) + /** Returns the corresponding Wasm type for this [[`ErasedType`]]. */ + private[text] def wasmType(using Ctx): Opt[RefType] = + import Ctx.ctx + et match + case ErasedType.Primitive(PrimitiveType.Int | PrimitiveType.Int31 | PrimitiveType.Bool) => + S(RefType.i31ref) + case ErasedType.AnyRef(_, csym: ClassLikeSymbol) => + csym.asBlkMember.flatMap(ctx.getType).map(RefType(_, nullable = false)) + case _ => N + +extension (sym: ValueSymbol) + /** The Wasm reference type a *local* slot for `sym` should be declared with. + * + * Use [[`FunctionCtx.slotRefType`]] for parameter slots, which handles `anyref` widening due to virtual dispatch + * calling conventions. + */ + private[text] def localRefType(using Ctx, State): RefType = + import Ctx.ctx + sym match + case isym: InnerSymbol => + val structSym = + if isym eq State.unitSymbol then S(State.unitBlockMemberSymbol) + else isym.asBlkMember + structSym.flatMap(ctx.getType).map(RefType(_, nullable = false)).getOrElse(RefType.anyref) + case s: HasErasedType => + s.erasedType.flatMap(_.wasmType).getOrElse(RefType.anyref) + case _ => RefType.anyref + + /** The Wasm reference type a parameter slot for `sym` should be declared with, if typed parameters are enabled. */ + private[text] def paramRefType(using Ctx, State): RefType = + sym match + case s: HasErasedType => + s.erasedType.flatMap(_.wasmType).getOrElse(RefType.anyref) + case _ => RefType.anyref + extension (exprs: Seq[Expr]) /** Merges a sequence of expressions to a single `block` expression if needed. */ def mergeAsBlock: Opt[Expr] = @@ -73,13 +109,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private val baseObjectSym: BlockMemberSymbol = BlockMemberSymbol("Object", Nil) /** Synthetic field symbol for the object-header pointer to a class's shared RTTI object. */ - private val typeInfoFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$typeinfo")) + private val typeInfoFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$typeinfo"), erasedType = N) /** Synthetic field symbol for the runtime class tag stored in RTTI. */ - private val tagFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$tag")) + private val tagFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$tag"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int))) /** Synthetic field symbol for the direct-parent RTTI reference used by runtime subtype checks. */ - private val parentFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$parent")) + private val parentFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$parent"), erasedType = N) private case class StringLitInfo(offset: Int, byteLen: Int, watBytes: Str) private val stringLits: LinkedHashMap[Str, StringLitInfo] = LinkedHashMap.empty @@ -138,6 +174,28 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: require(expr.resultTypes.size == 1, "expected single-result expression for cast") if expr.resultType.contains(target) then expr else ref.cast(expr, target) + /** Casts an expression to `target` type if the result type is a supertype of `target`. */ + private def downcastConserve(expr: Expr, target: RefType): Expr = + require(expr.resultTypes.size == 1, "expected single-result expression for cast") + target match + case rt: RefType if rt.heapType =/= HeapType.Any && expr.resultType.contains(RefType.anyref) => + castConserve(expr, rt) + case _ => expr + + /** Casts each argument in `wasmArgs` down to the corresponding declared parameter type read from `funcTypeInfo`, + * narrowing `anyref` -> a concrete typed parameter. + * + * Note that this function does not cast an argument that is already narrower than the declared parameter type, + * since it is illegal to upcast using `ref.cast`. + */ + private def castArgsToParams(wasmArgs: Seq[Expr], funcTypeInfo: TypeInfo): Seq[Expr] = + val declParams = funcTypeInfo.compType.asInstanceOf[FunctionType].sigType.params + wasmArgs.zip(declParams).map: (arg, p) => + p.valtype match + case rt: RefType if rt.heapType =/= HeapType.Any && arg.resultType.contains(RefType.anyref) => + castConserve(arg, rt) + case _ => arg + /** Returns the default Wasm value for one struct field when eagerly constructing an object instance. */ private def defaultStructFieldValue(field: Field)(using Ctx, Raise): Expr = field.ty match case refTy: RefType if refTy.nullable => ref.`null`(refTy.heapType) @@ -151,9 +209,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: scrutTypeInfo: Expr, targetTypeInfo: Expr, )(using Ctx, FunctionCtx, Raise): Expr = - val currentTmp = mkTempLocal("currentTypeInfo") - val targetTmp = mkTempLocal("targetTypeInfo") - val resultTmp = mkTempLocal("typeInfoMatch") + val currentTmp = mkTempLocal("currentTypeInfo", erasedType = N) + val targetTmp = mkTempLocal("targetTypeInfo", erasedType = N) + val resultTmp = mkTempLocal("typeInfoMatch", erasedType = S(ErasedType.Primitive(PrimitiveType.Bool))) funcCtx.withLabel(LabelSymbol(N, "typeInfo"), hasContinueLabel = true): case LabelTarget(breakLabel, S(continueLabel)) => Seq( @@ -303,9 +361,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val ctorCall = call( funcidx = ctx.getFunc_!(clsLikeDefn.sym), operands = Seq.empty, - returnTypes = Seq(Result(RefType.anyref)), + returnTypes = Seq(Result(globalTy)), ) - ctx.addSingletonInitAction(global.set(globalIdx, ref.cast(ctorCall, globalTy))) + ctx.addSingletonInitAction(global.set(globalIdx, ctorCall)) end registerSingletonInit /** Collects only top-level class definitions in `block`. */ @@ -487,12 +545,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: defn: ClsLikeDefn, suffix: Str, params: Seq[ValueSymbol -> SymIdx], + results: Seq[Result], )(using Ctx, Raise): TypeIdx = ctx.addType(TypeInfo( - sym = TempSymbol(N, defn.sym.nme), + sym = TempSymbol(N, erasedType = N, defn.sym.nme), FunctionType( params = params.map(p => WasmParam(p._2, RefType.anyref)), - results = Seq(Result(RefType.anyref)), + results = results, ), objectTag = N, wrapId = N -> S(suffix), @@ -510,7 +569,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private def virtualMethodFuncType(arity: Int)(using Ctx, Raise): TypeIdx = ctx.getOrCreateWasmIntrinsicType(WasmIntrinsicType.VirtualMethod(arity)): ctx.addType(TypeInfo( - sym = TempSymbol(N, s"virtual$arity"), + sym = TempSymbol(N, erasedType = N, s"virtual$arity"), compType = virtualMethodSignature(arity), objectTag = N, )) @@ -525,17 +584,19 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: defn: ClsLikeDefn, suffix: Str, params: Seq[ValueSymbol -> SymIdx], + results: Seq[Result], sym: BlockMemberSymbol, exportName: Opt[Str], )(using Ctx, Raise): Unit = - val funcTy = declareClassFuncType(defn, suffix, params) - predeclareClassFuncWithType(defn, suffix, params, sym, exportName, funcTy) + val funcTy = declareClassFuncType(defn, suffix, params, results) + predeclareClassFuncWithType(defn, suffix, params, results, sym, exportName, funcTy) /** Registers a placeholder class-associated function using a predeclared Wasm function type. */ private def predeclareClassFuncWithType( defn: ClsLikeDefn, suffix: Str, params: Seq[ValueSymbol -> SymIdx], + resultTypes: Seq[Result], sym: BlockMemberSymbol, exportName: Opt[Str], funcTy: TypeIdx, @@ -545,7 +606,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: wrapId = if sym.asClsOrMod.isDefined then (N -> S("ctor")) else (S(defn.sym.nme) -> N), typeUse = TypeUse(funcTy), params = params, - resultTypes = Seq(Result(RefType.anyref)), + resultTypes = resultTypes, locals = Seq.empty, body = ref.`null`(ctx.getType_!(defn.sym)), exportName = exportName, @@ -572,17 +633,18 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val initParams = (defn.isym -> SymIdx("this")) +: pl.params.map: p => p.sym -> SymIdx(p.sym.nme) - predeclareClassFunc(defn, "init", initParams, initFuncSym(defn.sym), N) + predeclareClassFunc(defn, "init", initParams, Seq(Result(RefType.anyref)), initFuncSym(defn.sym), N) /** Declares one top-level class constructor. */ private def predeclareClassConstructor(defn: ClsLikeDefn)(using Ctx, Raise): Unit = + val typeIdx = ctx.getType_!(defn.sym) val ctorParams = classCtorParamList(defn).params.map: p => p.sym -> SymIdx(p.sym.nme) val ctorExportName = defn.sym .optionIf: sym => !(defn.k is syntax.Obj) && sym.nameIsMeaningful .map(_.nme) - predeclareClassFunc(defn, "ctor", ctorParams, defn.sym, ctorExportName) + predeclareClassFunc(defn, "ctor", ctorParams, Seq(Result(RefType(typeIdx, nullable = false))), defn.sym, ctorExportName) /** Registers all Wasm pre-declarations needed for one top-level class, in dependency order. */ private def predeclareClass(defn: ClsLikeDefn)(using Ctx, Raise, SessionExportCtx): Unit = @@ -672,7 +734,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case func: SessionFunc => // If the function symbol comes from a class or module, generate a TempSymbol to avoid symbol collision with // the class/module itself - val funcTySym = TempSymbol(N, func.sym.nme) + val funcTySym = TempSymbol(N, erasedType = N, func.sym.nme) val typeIdx = ctx.addType(TypeInfo(sym = funcTySym, wrapId = func.wrapId, compType = func.funcType, objectTag = N)) ctx.addFunctionImport(WasmImport( @@ -708,7 +770,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: typeInfoTypeIdxs(cls.sym) = typeInfoTypeIdx val globalExtern = ExternType.Global( GlobalType(RefType(typeInfoTypeIdx, nullable = false), mutable = false), - TempSymbol(N, cls.sym.nme), + TempSymbol(N, erasedType = N, cls.sym.nme), wrapId = N -> S("typeinfo"), ) val globalIdx = ctx.addGlobalImport(WasmImport( @@ -738,7 +800,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val newSlotFields = currentVirtualMethods.zipWithIndex.drop(parentVirtualMethodCount).map: (methodSym, slot) => val methodDefn = defn.methods.find(_.sym == methodSym).get val arity = 1 + methodDefn.params.headOption.fold(0)(_.params.size) - val fieldSym = TermSymbol(syntax.MutVal, owner = N, Ident(s"slot$slot")) + val fieldSym = TermSymbol(syntax.MutVal, owner = N, Ident(s"slot$slot"), erasedType = N) fieldSym -> Field( RefType(virtualMethodFuncType(arity), nullable = true), mutable = true, @@ -746,7 +808,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ) val typeInfoType = ctx.addType(TypeInfo( - sym = TempSymbol(N, defn.sym.nme), + sym = TempSymbol(N, erasedType = N, defn.sym.nme), compType = StructType(fields = inheritedFields ++ newSlotFields, parents = Seq(parentTypeInfoIdx)), objectTag = N, wrapId = N -> S("typeinfo"), @@ -796,12 +858,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ownerCls, methodDefn.sym.nme, methodParams, + Seq(Result(RefType.anyref)), methodDefn.sym, N, virtualMethodFuncType(methodParams.size), ) case N => - predeclareClassFunc(ownerCls, methodDefn.sym.nme, methodParams, methodDefn.sym, N) + predeclareClassFunc(ownerCls, methodDefn.sym.nme, methodParams, Seq(Result(RefType.anyref)), methodDefn.sym, N) /** Declares placeholders for all methods on one top-level class. */ private def predeclareClassMethods(defn: ClsLikeDefn)(using Ctx, Raise): Unit = @@ -818,7 +881,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Gets (and caches) the exception tag used for MLX `throw`. */ private def exnTagIdx(using Ctx, Raise): TagIdx = - val sym = TempSymbol(N, "mlx_exn") + val sym = TempSymbol(N, erasedType = N, "mlx_exn") ctx.getOrCreateWasmIntrinsicTag( "mlx_exn", ctx.addTag(TagInfo( @@ -868,7 +931,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: module = ExternIntrinsics.SystemModule, name = ExternIntrinsics.StringFromUtf16ImportName, ): - val importTySym = TempSymbol(N, ExternIntrinsics.StringFromUtf16ImportName) + val importTySym = TempSymbol(N, erasedType = N, ExternIntrinsics.StringFromUtf16ImportName) val importTy = ctx.addType(TypeInfo( sym = importTySym, compType = FunctionType( @@ -902,8 +965,8 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Allocates a fresh temp local (typed `anyref`) and returns its `LocalIdx`. */ - private def mkTempLocal(base: Str)(using Ctx, FunctionCtx, Raise): LocalIdx = - funcCtx.addLocal(TempSymbol(N, base)) + private def mkTempLocal(base: Str, erasedType: Opt[ErasedType])(using Ctx, FunctionCtx, Raise): LocalIdx = + funcCtx.addLocal(TempSymbol(N, erasedType, base)) /** Binds constructor self (`thisSym`) to the Wasm local name `this` in the current function context. */ @@ -1007,7 +1070,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val elemType = RefType.anyref val mutArrayType = tupleArrayType(true) val immArrayType = tupleArrayType(false) - val tupleTmp = mkTempLocal("tuple") + val tupleTmp = mkTempLocal("tuple", erasedType = N) val tupleIsMutable = ref.test(local.tee(tupleTmp, tupleExpr), RefType(mutArrayType, nullable = true)) val tupleValue = local.get(tupleTmp, RefType.anyref) val mutableBranch = @@ -1052,7 +1115,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: extraInfo = S(errExtra), ) - val idxTmp = mkTempLocal("idx") + val idxTmp = mkTempLocal("idx", erasedType = S(ErasedType.Primitive(PrimitiveType.Int31))) tupleRef => val storeIdx = local.set(idxTmp, ref.i31(idxI32)) @@ -1090,7 +1153,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: lastWords(s"ValueSymbol `$ts` (${ts.getClass.getSimpleName}) cannot be resolved as a variable") case l => funcCtx.lookupLocal(l) match - case S(localIdx) => local.get(localIdx, RefType.anyref) + case S(localIdx) => local.get(localIdx, funcCtx.slotRefType(l)) case N if ctx.containsGlobal(l) => global.get(ctx.getGlobal_!(l), ctx.getGlobalType_!(l).globalType.valType) case _ => @@ -1155,7 +1218,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ctx.getVirtualTable(ownerCls).flatMap(_.virtualMethodSlots.get(methodSym)) match case S(slot) => val ownerTypeInfoIdx = typeInfoTypeIdxs(ownerCls) - val receiverTmp = mkTempLocal("receiver") + val receiverTmp = mkTempLocal("receiver", erasedType = N) val receiverExpr = local.set(receiverTmp, result(qual)) val receiverRef = local.get(receiverTmp, RefType.anyref) val ownerTypeInfoRef = ref.cast( @@ -1222,9 +1285,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: singletonInfoFor(sym) match case S(info) => singletonGlobalGet(info) case N => - // TODO(Derppening): Remove `ref.cast` once erased-typed IR is implemented - ref.cast( - local.get(funcCtx.lookupLocal_!(sym, sym.toLoc), RefType.anyref), + // The cast is necessary if `$this` is a parameter rather than a local, since `$this` parameters are typed as + // `anyref` to support virtual dispatch + castConserve( + local.get(funcCtx.lookupLocal_!(sym, sym.toLoc), funcCtx.slotRefType(sym)), RefType( sym.asBlkMember.fold(baseObjectTypeIdx)(ctx.getType_!(_)), nullable = false, @@ -1290,7 +1354,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: extraInfo = S(fun.toString), ) val baseTypeInfo = ctx.getTypeInfo_!(ctx.getFuncTypeUse_!(baseFuncIdx).typeIdx) - val wasmArgs = args.map(argument) + val wasmArgs = castArgsToParams(args.map(argument), baseTypeInfo) call( funcidx = baseFuncIdx, @@ -1306,7 +1370,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: extraInfo = S(fun.toString), ) val baseTypeInfo = ctx.getTypeInfo_!(ctx.getFuncTypeUse_!(baseFuncIdx).typeIdx) - val wasmArgs = args.map(argument) + val wasmArgs = castArgsToParams(args.map(argument), baseTypeInfo) call( funcidx = baseFuncIdx, @@ -1456,17 +1520,16 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: Ls(msg"Class path for an Instantiate(...) expression must be resolved" -> cls.toLoc), extraInfo = S(s"Block IR of `cls` expression: ${cls.toString}"), ) - val ctorClsBlkSym = ctorClsSym.asBlkMember match - case S(sym) => sym - case N => lastWords( - s"Expected resolved class for an Instantiate(...) expression to be a BlockMemberSymbol, but got ${ - ctorClsSym.getClass.getName - }", - ) - val ctorFuncIdx = ctx.getFunc(ctorClsBlkSym) match - case S(idx) => idx - case N => lastWords(s"Missing constructor definition for class ${ctorClsBlkSym.toString}") - call(funcidx = ctorFuncIdx, as.map(argument), Seq(Result(RefType.anyref))) + val ctorClsBlkSym = ctorClsSym.asBlkMember.getOrElse: + lastWords: + s"Expected resolved class for an Instantiate(...) expression to be a BlockMemberSymbol, but got ${ + ctorClsSym.getClass.getName + }" + val ctorFuncIdx = ctx.getFunc(ctorClsBlkSym).getOrElse: + lastWords(s"Missing constructor definition for class ${ctorClsBlkSym.toString}") + val ctorClsTypeIdx = ctx.getType(ctorClsBlkSym).getOrElse: + lastWords(s"Missing class definition for class ${ctorClsBlkSym.toString}") + call(funcidx = ctorFuncIdx, as.map(argument), Seq(Result(RefType(ctorClsTypeIdx, nullable = false)))) case Tuple(mut, elems) => val tupleValues = elems.map(argument) @@ -1497,7 +1560,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ctx.addFunctionImport(WasmImport( ExternIntrinsics.SystemModule, name, - ExternType.Func(TypeUse(typeIdx), TempSymbol(N, name), wrapId = N -> N), + ExternType.Func(TypeUse(typeIdx), TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), name), wrapId = N -> N), )) /** Creates the intrinsic definition for `name`. @@ -1512,7 +1575,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private def declareIntrinsicType(name: Str)(using Ctx, Raise): TypeIdx = ctx.addType(TypeInfo( - sym = TempSymbol(N, name), + sym = TempSymbol(N, erasedType = N, name), compType = FunctionType( params = intrinsicParamSuffixes(name).map(nme => WasmParam(SymIdx(nme), RefType.anyref)), results = Seq(Result(RefType.anyref)), @@ -1555,7 +1618,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: )(using Ctx, Raise): FuncIdx = val funcTy = declareIntrinsicType(name) val funcInfo = FuncInfo( - sym = TempSymbol(N, name), + sym = TempSymbol(N, erasedType = N, name), typeUse = TypeUse(funcTy), params = params, locals = Seq.empty, @@ -1605,9 +1668,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: /** Creates parameters for an intrinsic. */ + // TODO(Derppening): WTF? Remove `name` and add erasedType to `params` private def mkIntrinsicParams(name: Str, suffixes: Ls[Str]): Ls[TempSymbol -> SymIdx] = suffixes.map: suffix => - val sym = TempSymbol(N, suffix) + val sym = TempSymbol(N, erasedType = N, suffix) sym -> SymIdx(suffix) /** Loads the local `name` as an `anyref`. @@ -1646,11 +1710,14 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case Assign(l: ValueSymbol, r, rst) => val lExpr = getVar(l, l.toLoc) val rExpr = result(r) + val rExprCasted = lExpr.resultType match + case S(rt: RefType) => downcastConserve(rExpr, rt) + case _ => rExpr val assignExpr = lExpr.mnemonicPrefix match case S("global") => - global.set(lExpr.instrargs(0).asInstanceOf[GlobalIdx], rExpr) + global.set(lExpr.instrargs(0).asInstanceOf[GlobalIdx], rExprCasted) case S("local") => - local.set(lExpr.instrargs(0).asInstanceOf[LocalIdx], rExpr) + local.set(lExpr.instrargs(0).asInstanceOf[LocalIdx], rExprCasted) case _ => lastWords( s"Expected `global.*` or `local.*` when compiling instruction for `$l`, but got ${lExpr.mnemonic}", @@ -1732,11 +1799,15 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case N => val localStorageSym = defn.sym val symExpr = getVar(localStorageSym, localStorageSym.toLoc) + val pExpr = result(p) + val pExprCasted = symExpr.resultType match + case S(rt: RefType) => downcastConserve(pExpr, rt) + case _ => pExpr val defineExpr = symExpr.mnemonicPrefix match case S("global") => - global.set(symExpr.instrargs(0).asInstanceOf[GlobalIdx], result(p)) + global.set(symExpr.instrargs(0).asInstanceOf[GlobalIdx], pExprCasted) case S("local") => - local.set(symExpr.instrargs(0).asInstanceOf[LocalIdx], result(p)) + local.set(symExpr.instrargs(0).asInstanceOf[LocalIdx], pExprCasted) case _ => lastWords( s"Expected `global.*` or `local.*` when compiling definition for `$sym`, but got ${symExpr.mnemonic}", @@ -1771,13 +1842,13 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val result = pss.foldRight(bod): case (ps, block) => Return(Lambda(ps, block)(Nil)) - val (bodyWat, fnCtx) = setupFunction(N, ps, result) + val (bodyWat, fnCtx) = setupFunction(N, ps, result, typedParams = true) if sym.nameIsMeaningful then val funcTy = ctx.addType( TypeInfo( - sym = TempSymbol(N, sym.nme), + sym = TempSymbol(N, erasedType = N, sym.nme), compType = FunctionType( - params = fnCtx.params.map(p => WasmParam(p._2, RefType.anyref)), + params = fnCtx.params.map(p => WasmParam(p._2, p._1.paramRefType)), results = Seq.fill(bodyWat.resultTypes.length)(Result(RefType.anyref)), ), objectTag = N, @@ -1792,6 +1863,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: locals = fnCtx.locals, body = bodyWat, exportName = sym.optionIf(_.nameIsMeaningful).map(_.nme), + typedParams = true, ) ctx.addFunc(funcInfo) if summon[SessionExportCtx].shouldExport(defn.sym) then @@ -1870,9 +1942,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: Seq( local.set(thisVar, struct.`new`(typeref, instanceFields)), drop(initCall), - // TODO(Derppening): Restore once we fix custom reftypes - // `return`(S(local.get(thisVar, RefType(typeref, nullable = false)))), - `return`(S(local.get(thisVar, RefType.anyref))), + `return`(S(local.get(thisVar, RefType(typeref, nullable = false)))), ).mergeAsBlock_! val predeclaredInit = ctx.getFuncInfo_!(initFuncRef) @@ -1950,7 +2020,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: funcType = FunctionType( SignatureType( params = ctorFnCtx.params.map(p => WasmParam(p._2, RefType.anyref)), - results = Seq(Result(RefType.anyref)), + results = ctorCode.resultTypes.map(ty => Result(ty.asValType_!)), ), ), )) @@ -2059,11 +2129,11 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case Match(scrut, arms, dflt, rst) => val tailMode = rst.isInstanceOf[End] val matchResLocal = - if tailMode then S(mkTempLocal("matchRes")) + if tailMode then S(mkTempLocal("matchRes", erasedType = N)) else N val scrutLocalResult = scrut match case _: (Value.RefLike | Value.Lit) => N - case _ => S(mkTempLocal("scrut")) + case _ => S(mkTempLocal("scrut", erasedType = N)) val scrutInitExpr = scrutLocalResult.map: scrutLocal => local.set(scrutLocal, result(scrut)) @@ -2323,7 +2393,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val entrySym = BlockMemberSymbol("entry", Nil) val entryFnTy = ctx.addType(TypeInfo( - sym = TempSymbol(N, entrySym.nme), + sym = TempSymbol(N, erasedType = N, entrySym.nme), FunctionType(params = Seq.empty, results = Seq(Result(RefType.anyref))), objectTag = N, )) @@ -2344,19 +2414,19 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: offset = i32.const(lit.offset), bytes = Seq(lit.watBytes), memuse = N, - sym = TempSymbol(N, s.take(WatBuilder.StringConstantIdentMaxLength)), + sym = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), s.take(WatBuilder.StringConstantIdentMaxLength)), )) val initActions = ctx.getSingletonInitActions if initActions.nonEmpty then val initTy = ctx.addType(TypeInfo( - sym = TempSymbol(N, "start"), + sym = TempSymbol(N, erasedType = N, "start"), compType = FunctionType(params = Seq.empty, results = Seq.empty), objectTag = N, )) val initBody = initActions.mergeAsBlock_! val initFn = ctx.addFunc(FuncInfo( - sym = TempSymbol(N, "start"), + sym = TempSymbol(N, erasedType = N, "start"), typeUse = TypeUse(initTy), params = Seq.empty, resultTypes = initBody.resultTypes.map(ty => Result(ty.asValType_!)), @@ -2393,8 +2463,9 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: thisParam: Opt[InnerSymbol], params: ParamList, body: Block, + typedParams: Bool = false, )(using Ctx, Raise, SessionExportCtx): (Expr, FunctionCtx) = - genFuncBody(params :: Nil, thisSym = thisParam): + genFuncBody(params :: Nil, thisSym = thisParam, typedParams = typedParams): block(body).mergeAsBlock.getOrElse(nop) end WatBuilder diff --git a/hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala b/hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala index 18261d36fa..3fc9e60fef 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala @@ -142,7 +142,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): case _ => error(msg"Effect cannot be polymorphic." -> ty.toLoc :: Nil) }).getOrElse(Bot)) case f @ Term.Forall(tvs, outer, body) => - val outVar = freshOuter(outer.getOrElse(new TempSymbol(S(f), "outer")))(using ctx) + val outVar = freshOuter(outer.getOrElse(new TempSymbol(S(f), erasedType = N, "outer")))(using ctx) val nestCtx = ctx.nestWithOuter(outVar) outer.foreach(sym => nestCtx += sym -> outVar) given InvalCtx = nestCtx @@ -204,7 +204,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): typeAndSubstType(ty, pol = true)(using Map.empty) private def instantiate(ty: PolyType)(using ctx: InvalCtx): GeneralType = - ty.instantiate(infVarState.nextUid, freshEnv(new TempSymbol(N, "env")), ctx.lvl)(tl) + ty.instantiate(infVarState.nextUid, freshEnv(new TempSymbol(N, erasedType = N, "env")), ctx.lvl)(tl) private def extrude(ty: GeneralType)(using ctx: InvalCtx, pol: Bool, cctx: CCtx): GeneralType = ty match case ty: Type => solver.extrude(ty)(using ctx.lvl, pol, HashMap.empty) @@ -240,7 +240,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): nestCtx &= (sym, tv, sk) (tv, sk) val (bodyTy, ctxTy, eff) = typeCode(body) - val res = freshVar(new TempSymbol(S(f), "ctx"))(using ctx) + val res = freshVar(new TempSymbol(S(f), erasedType = N, "ctx"))(using ctx) constrain(ctxTy, bds.foldLeft[Type](res)((res, bd) => res | bd._2)) (FunType(bds.map(_._1), bodyTy, Bot), res, eff) case app @ Term.App(lhs, Term.Tup(rhs)) => @@ -250,7 +250,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): val (ty, ctx, eff) = typeCode(p.term) (ty :: res._1, res._2 | ctx, res._3 | eff) case (_, spd: Spd) => TODO(s"spread arguments in quoted code: $spd") - val resTy = freshVar(new TempSymbol(S(app), "app")) + val resTy = freshVar(new TempSymbol(S(app), erasedType = N, "app")) constrain(lhsTy, FunType(rhsTy.reverse, resTy, Bot)) // TODO: right (resTy, lhsCtx | rhsCtx, lhsEff | rhsEff) case sel @ Term.SynthSel(Term.Ref(_: TopLevelSymbol), _) if sel.symbol.isDefined => @@ -258,8 +258,8 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): (tryMkMono(opTy, sel), Bot, eff) case unq @ Term.Unquoted(body) => val (ty, eff) = typeCheck(body) - val tv = freshVar(new TempSymbol(S(unq), "cde")) - val cr = freshVar(new TempSymbol(S(unq), "ctx")) + val tv = freshVar(new TempSymbol(S(unq), erasedType = N, "cde")) + val cr = freshVar(new TempSymbol(S(unq), erasedType = N, "ctx")) constrain(tryMkMono(ty, body), InvalCtx.codeTy(tv, cr)) (tv, cr, eff) case blk @ Term.Blk(LetDecl(sym, _) :: DefineVar(sym2, rhs) :: Nil, body) @@ -270,7 +270,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): val sk = freshSkolem(sym) nestCtx &= (sym, rhsTy, sk) val (bodyTy, bodyCtx, bodyEff) = typeCode(body) - val res = freshVar(new TempSymbol(S(blk), "ctx"))(using ctx) + val res = freshVar(new TempSymbol(S(blk), erasedType = N, "ctx"))(using ctx) constrain(bodyCtx, sk | res) (bodyTy, rhsCtx | res, rhsEff | bodyEff) case Term.IfLike(_, IfLikeForm.ReturningIf, SimpleSplit.IfThenElse(cond, cons, alts)) => @@ -290,7 +290,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): ascribe(lam, sigTy) () case N => - val outer = freshOuter(new TempSymbol(S(lam), "outer"))(using ctx) + val outer = freshOuter(new TempSymbol(S(lam), erasedType = N, "outer"))(using ctx) given InvalCtx = ctx.nestWithOuter(outer) val funTyV = freshVar(sym) ctx += sym -> funTyV // for recursive functions @@ -417,7 +417,7 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): case S(sym) => val (clsTy, tv, emptyTy) = sym.defn.map(sym -> _) match case S((sym, cls)) => - (ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(sym))), (freshVar(new TempSymbol(S(scrutinee), "scrut"))), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty))) + (ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(sym))), (freshVar(new TempSymbol(S(scrutinee), erasedType = N, "scrut"))), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty))) case _ => error(msg"Cannot match ${scrutinee.toString} as ${sym.toString}" -> split.toLoc :: Nil) (Bot, Bot, Bot) @@ -543,8 +543,8 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): Left(ty) :: Right(eff) :: Nil case spd: Spd => TODO(s"spread arguments: $spd") .partitionMap(x => x) - val effVar = freshVar(new TempSymbol(S(t), "eff")) - val retVar = freshVar(new TempSymbol(S(t), "app")) + val effVar = freshVar(new TempSymbol(S(t), erasedType = N, "eff")) + val retVar = freshVar(new TempSymbol(S(t), erasedType = N, "app")) constrain(tryMkMono(funTy, t), FunType(argTy.map((tryMkMono(_, t))), retVar, effVar)) (retVar, argEff.foldLeft[Type](effVar | lhsEff)((res, e) => res | e)) @@ -748,25 +748,25 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx): given InvalCtx = nestCtx nestCtx += sym -> InvalCtx.regionTy(sk) val (res, eff) = typeCheck(body) - val tv = freshVar(new TempSymbol(S(reg), "eff"))(using ctx) + val tv = freshVar(new TempSymbol(S(reg), erasedType = N, "eff"))(using ctx) constrain(eff, tv | sk) (extrude(res)(using ctx, true), tv) case Term.RegRef(reg, value) => val (regTy, regEff) = typeCheck(reg) val (valTy, valEff) = typeCheck(value) - val sk = freshVar(new TempSymbol(S(reg), "reg")) + val sk = freshVar(new TempSymbol(S(reg), erasedType = N, "reg")) constrain(tryMkMono(regTy, reg), InvalCtx.regionTy(sk)) (InvalCtx.refTy(tryMkMono(valTy, value), sk), sk | (regEff | valEff)) case Term.SetRef(lhs, rhs) => val (lhsTy, lhsEff) = typeCheck(lhs) val (rhsTy, rhsEff) = typeCheck(rhs) - val sk = freshVar(new TempSymbol(S(lhs), "reg")) + val sk = freshVar(new TempSymbol(S(lhs), erasedType = N, "reg")) constrain(tryMkMono(lhsTy, lhs), InvalCtx.refTy(tryMkMono(rhsTy, rhs), sk)) (tryMkMono(rhsTy, rhs), sk | (lhsEff | rhsEff)) case Term.Deref(ref) => val (refTy, refEff) = typeCheck(ref) - val sk = freshVar(new TempSymbol(S(ref), "reg")) - val ctnt = freshVar(new TempSymbol(S(ref), "ref")) + val sk = freshVar(new TempSymbol(S(ref), erasedType = N, "reg")) + val ctnt = freshVar(new TempSymbol(S(ref), erasedType = N, "ref")) constrain(tryMkMono(refTy, ref), InvalCtx.refTy(ctnt, sk)) (ctnt, sk | refEff) case Term.Quoted(body) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 74b63abcbe..3e8be60e14 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -17,6 +17,7 @@ import hkmc2.Message.MessageContext import Keyword.{`and`, `case`, `do`, `else`, `if`, `is`, `let`, `or`, `set`, `then`, `while`} import hkmc2.utils.Scope +import codegen.ErasedType import SimpleSplit.* import ucs.{error, unapply} @@ -414,15 +415,16 @@ object Elaborator: def tupleSymbol: ModuleOrObjectSymbol = runtimeSymbols.tuple def strSymbol: ModuleOrObjectSymbol = runtimeSymbols.str // In JavaScript, `import` can be used for getting current file path, as `import.meta` - val importSymbol = new VarSymbol(Ident("import")) + val importSymbol = new VarSymbol(Ident("import"), erasedType = N) + @deprecated("Use the `NoSymbol` singleton instead.") val noSymbol = NoSymbol - val runtimeSymbol = TempSymbol(N, "runtime") - val definitionMetadataSymbol = TempSymbol(N, "definitionMetadata") - val prettyPrintSymbol = TempSymbol(N, "prettyPrint") - val termSymbol = TempSymbol(N, "Term") - val blockSymbol = TempSymbol(N, "Block") - val optionSymbol = TempSymbol(N, "option") - val wasmSymbol = TempSymbol(N, "wasm") + val runtimeSymbol = TempSymbol(N, erasedType = N, "runtime") + val definitionMetadataSymbol = TempSymbol(N, erasedType = N, "definitionMetadata") + val prettyPrintSymbol = TempSymbol(N, erasedType = N, "prettyPrint") + val termSymbol = TempSymbol(N, erasedType = N, "Term") + val blockSymbol = TempSymbol(N, erasedType = N, "Block") + val optionSymbol = TempSymbol(N, erasedType = N, "option") + val wasmSymbol = TempSymbol(N, erasedType = N, "wasm") val nonLocalRetHandlerTrm = val id = new Ident("NonLocalReturn") val sym = ClassSymbol(DummyTypeDef(syntax.Cls), id) @@ -451,7 +453,8 @@ object Elaborator: binary = binaryOps(op), unary = unaryOps(op), nullary = false, - functionLike = anyOps(op)) + functionLike = anyOps(op), + erasedType = N) .toMap baseBuiltins ++ aliasOps.map: case (alias, base) => alias -> baseBuiltins(base) @@ -588,10 +591,10 @@ extends Importer: )(using State): Term = val clsSym = ClassSymbol(DummyTypeDef(Cls), Ident(effectClassName)) val htds = methods.map: spec => - val valueSym = spec.valueParamName.map(nme => VarSymbol(Ident(nme))) - val resumeSym = VarSymbol(Ident("resume")) + val valueSym = spec.valueParamName.map(nme => VarSymbol(Ident(nme), erasedType = N)) + val resumeSym = VarSymbol(Ident("resume"), erasedType = N) val mtdSym = BlockMemberSymbol(spec.methodName, Nil, true) - val tsym = TermSymbol(Fun, N, Ident(spec.methodName)) + val tsym = TermSymbol(Fun, N, Ident(spec.methodName), erasedType = N) val td = TermDefinition( Fun, mtdSym, @@ -777,12 +780,12 @@ extends Importer: private def branch(using Ctx): Cfg[PartialFunction[Tree, (Ctx, SimpleSplit)]] = // Interleaved-`let` bindings like `{ x is A then 0; let x = 1; ... }`. case LetLike(Keywrd(`let`), ident: Ident, S(rhsTree), N) => - val symbol = VarSymbol(ident) + val symbol = VarSymbol(ident, erasedType = N) val head = Head.Let(symbol, term(rhsTree)) ((ctx + (ident.name -> symbol)), head ~: End) // Interleaved-`do` statements like `{ x is A then 0; do log(1); ... }`. case PrefixApp(Keywrd(`do`), rhsTree) => - (ctx, Head.Let(TempSymbol(N, "unused"), term(rhsTree)) ~: End) + (ctx, Head.Let(TempSymbol(N, erasedType = N, "unused"), term(rhsTree)) ~: End) // Although the `else`-clause marks the end of the split, we cannot // stop and still have to elaborate the remaining trees. case PrefixApp(kwTree @ Keywrd(`else`), elseTree) => @@ -882,7 +885,7 @@ extends Importer: case Term.Ref(symbol) => continuation(() => symbol.ref().withLocOf(term)) // Otherwise, we need to create a temporary symbol holding the term. case term: Term => - val symbol = TempSymbol(N, "scrut") + val symbol = TempSymbol(N, erasedType = N, "scrut") Head.Let(symbol, term) ~: continuation(() => symbol.ref()) private type TT = (Tree, Tree) @@ -1000,7 +1003,7 @@ extends Importer: error else val lt = subterm(lhs) - val sym = TempSymbol(S(lt), "old") + val sym = TempSymbol(S(lt), erasedType = N, "old") Blk( LetDecl(sym, Nil) :: DefineVar(sym, lt) :: Nil, Term.Try(Blk( Term.Assgn(lt, subterm(rhs)) :: Nil, @@ -1014,7 +1017,7 @@ extends Importer: Term.Try(subterm(tryBody), subterm(finallyBody)) case (hd @ Hndl(id: Ident, c, Block(sts_), S(bod))) => ctx.nest(OuterCtx.LambdaOrHandlerBlock).givenIn: - val sym = VarSymbol(id) + val sym = VarSymbol(id, erasedType = N) log(s"Processing `handle` statement $id (${sym}) ${ctx.outer}") val derivedClsSym = ClassSymbol(Tree.DummyTypeDef(syntax.Cls), Tree.Ident(s"Handler$$${id.name}$$")) @@ -1080,20 +1083,21 @@ extends Importer: })(N).withLocOf(tree) case InfixApp(TyTup(tvs), Keywrd(Keyword.`->`), body) => val boundVars = mutable.HashMap.empty[Str, VarSymbol] - def genSym(id: Tree.Ident) = - val sym = VarSymbol(id) + def genSym(id: Tree.Ident, erasedType: Opt[ErasedType]) = + val sym = VarSymbol(id, erasedType) sym.decl = S(TyParam(FldFlags.empty, N, sym)) // TODO vce boundVars += id.name -> sym sym val syms = (tvs.collect: - case id: Tree.Ident => (genSym(id), N, N) - case InfixApp(id: Tree.Ident, Keywrd(Keyword.`extends`), ub) => (genSym(id), S(ub), N) - case InfixApp(id: Tree.Ident, Keywrd(Keyword.`restricts`), lb) => (genSym(id), N, S(lb)) - case InfixApp(InfixApp(id: Tree.Ident, Keywrd(Keyword.`extends`), ub), Keywrd(Keyword.`restricts`), lb) => (genSym(id), S(ub), S(lb)) + case id: Tree.Ident => (genSym(id, erasedType = N), N, N) + case InfixApp(id: Tree.Ident, Keywrd(Keyword.`extends`), ub) => (genSym(id, erasedType = N), S(ub), N) + case InfixApp(id: Tree.Ident, Keywrd(Keyword.`restricts`), lb) => (genSym(id, erasedType = N), N, S(lb)) + case InfixApp(InfixApp(id: Tree.Ident, Keywrd(Keyword.`extends`), ub), Keywrd(Keyword.`restricts`), lb) => + (genSym(id, erasedType = N), S(ub), S(lb)) ) val outer = (tvs.collect: - case Outer(S(name: Tree.Ident)) => genSym(name) - case Outer(N) => genSym(Tree.Ident("outer")) + case Outer(S(name: Tree.Ident)) => genSym(name, erasedType = N) + case Outer(N) => genSym(Tree.Ident("outer"), erasedType = N) ) match case ot :: Nil => S(ot) case _ :: rest => @@ -1276,8 +1280,8 @@ extends Importer: msg" add a space before ‹identifier› to make it an operator application." -> N :: Nil N - val self = VarSymbol(Ident("self")) - val args = VarSymbol(Ident("args")) + val self = VarSymbol(Ident("self"), erasedType = N) + val args = VarSymbol(Ident("args"), erasedType = N) val ps = ParamList(ParamListFlags.empty, Param(FldFlags.empty, self, N, Modulefulness.none) :: Nil, S: @@ -1350,7 +1354,7 @@ extends Importer: case Quoted(body) => Term.Quoted(subterm(body)) case Unquoted(body) => Term.Unquoted(subterm(body)) case tree @ Case(kw, _) => - val scrut = VarSymbol(Ident("caseScrut")) + val scrut = VarSymbol(Ident("caseScrut"), erasedType = N) val body = caseSplit(scrut, tree) val params = Param(FldFlags.empty, scrut, N, Modulefulness.none) :: Nil Term.Lam(PlainParamList(params), body).mkLocWith(kw) @@ -1379,10 +1383,10 @@ extends Importer: Term.Throw(subterm(body)).mkLocWith(kw) case PrefixApp(kw @ Keywrd(Keyword.`do`), InfixApp(labelId: Ident, Keywrd(Keyword.`:`), body)) => val labelSym = new LabelSymbol(N, labelId.name) - val resultSym = new TempSymbol(N, s"${labelId.name}$$result") - val nonLocalHandlerSym = TempSymbol(N, s"nonLocalHandler$$${labelId.name}") - val nonLocalBreakMethodMarker = TempSymbol(N, s"nonLocalBreakMethod$$${labelId.name}") - val nonLocalContinueMethodMarker = TempSymbol(N, s"nonLocalContinueMethod$$${labelId.name}") + val resultSym = new TempSymbol(N, erasedType = N, s"${labelId.name}$$result") + val nonLocalHandlerSym = TempSymbol(N, erasedType = N, s"nonLocalHandler$$${labelId.name}") + val nonLocalBreakMethodMarker = TempSymbol(N, erasedType = N, s"nonLocalBreakMethod$$${labelId.name}") + val nonLocalContinueMethodMarker = TempSymbol(N, erasedType = N, s"nonLocalContinueMethod$$${labelId.name}") val bodyTerm = ctx.withLabel( labelSym, resultSym, nonLocalHandlerSym, nonLocalBreakMethodMarker, nonLocalContinueMethodMarker).givenIn: subterm(body) @@ -1394,7 +1398,7 @@ extends Importer: case PrefixApp(kw @ Keywrd(Keyword.`drop`), body) => Term.Drop(subterm(body)).mkLocWith(kw) case Region(id: Ident, body) => - val sym = VarSymbol(id) + val sym = VarSymbol(id, erasedType = N) given Ctx = ctx + (id.name -> sym) Term.Region(sym, subterm(body)) case RegRef(reg, value) => Term.RegRef(subterm(reg), subterm(value)) @@ -1480,7 +1484,7 @@ extends Importer: raise(ErrorReport(msg"Illegal position for '_' placeholder." -> tree.toLoc :: Nil)) error case S(unds) => - val sym = VarSymbol(Ident("_" + unds.size)) + val sym = VarSymbol(Ident("_" + unds.size), erasedType = N) unds += sym sym.ref() case Annotated(lhs, rhs) => @@ -1726,7 +1730,7 @@ extends Importer: (base, term(rrhs)) val newAcc = rlhs match case id: Ident => - val sym = new VarSymbol(id) + val sym = new VarSymbol(id, erasedType = N) newCtx += id.name -> sym RcdField(Term.Lit(StrLit(id.name)).withLocOf(id), sym.ref(id)) :: DefineVar(sym, rhs_t) @@ -1840,7 +1844,7 @@ extends Importer: case N => N case _ if ctx.mode is Mode.Light => S(Term.Missing) case S(rhs) => S: - val nonLocalRetHandler = TempSymbol(N, s"nonLocalRetHandler$$${id.name}") + val nonLocalRetHandler = TempSymbol(N, erasedType = N, s"nonLocalRetHandler$$${id.name}") newCtx.nest(OuterCtx.Function(nonLocalRetHandler)).givenIn: newCtx ?=> val b = term(rhs)(using newCtx) if nonLocalRetHandler.directRefs.isEmpty then b else @@ -1861,7 +1865,7 @@ extends Importer: case _ => Modulefulness.none - val tsym = TermSymbol(k, owner, id) // TODO? + val tsym = TermSymbol(k, owner, id, erasedType = N) // TODO? val tdf = TermDefinition(k, sym, tsym, pss, tps, s, body, TermDefFlags.empty.copy(isMethod = isMethod), mfn, annotations, N).withLocOf(td) tsym.defn = S(tdf) @@ -1907,7 +1911,7 @@ extends Importer: case S(ts) => ts.tys.flatMap: targ => def mk(id: Ident, vce: Opt[Bool]): Ls[TyParam] = - val vs = VarSymbol(id) + val vs = VarSymbol(id, erasedType = N) val res = TyParam(FldFlags.empty, vce, vs) vs.decl = S(res) res :: Nil @@ -1981,7 +1985,7 @@ extends Importer: tsym.defn = S(fdef) fdef :: Nil else - val psym = TermSymbol(LetBind, owner, p.sym.id) + val psym = TermSymbol(LetBind, owner, p.sym.id, erasedType = N) psym.sourceAliases = p.sym.sourceAliases val decl = LetDecl(psym, Nil) val defn = DefineVar(psym, p.sym.ref()) @@ -1994,7 +1998,7 @@ extends Importer: val owner = td.symbol match case s: InnerSymbol => S(s) case _: TypeAliasSymbol => die - val psym = TermSymbol(LetBind, owner, p.sym.id) + val psym = TermSymbol(LetBind, owner, p.sym.id, erasedType = N) psym.sourceAliases = p.sym.sourceAliases val decl = LetDecl(psym, Nil) val defn = DefineVar(psym, p.sym.ref()) @@ -2238,8 +2242,8 @@ extends Importer: case N => N def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): TermSymbol | VarSymbol = - if ctx.outer.inner.isDefined then TermSymbol(k, ctx.outer.inner, id) - else VarSymbol(id) + if ctx.outer.inner.isDefined then TermSymbol(k, ctx.outer.inner, id, erasedType = N) + else VarSymbol(id, erasedType = N) def param(t: Tree, inUsing: Bool, inDataClass: Bool): Ctxl[Diagnostic \/ (Param, Opt[SpreadKind], Ls[Str])] = t.desugared.asParam(inUsing).map: @@ -2251,7 +2255,7 @@ extends Importer: new Ident(base).withLocOf(id) -> (id.name :: Nil) case N => id -> Nil - val sym = VarSymbol(canonicalId) + val sym = VarSymbol(canonicalId, erasedType = N) sym.sourceAliases = aliases val sig = sign.map(term(_)) val p = Param(flg, sym, sig, Modulefulness.ofSign(sig)(Mod in modifiers)) @@ -2384,7 +2388,7 @@ extends Importer: // We create a symbol specifically for `Param` for each variable to // avoid redundantly redeclaring symbols in `Scope` during code // generation, which triggers the assertion in `Scope.addToBindings`. - val parameterSymbol = VarSymbol(new Ident(symbol.name)) + val parameterSymbol = VarSymbol(new Ident(symbol.name), erasedType = N) (name -> parameterSymbol, symbol -> parameterSymbol) .toList.unzip pattern.variables.report // Report all invalid variables we found in `pattern`. @@ -2521,7 +2525,7 @@ extends Importer: case TyTup(ps) => val vs = ps.flatMap: case id: Ident => - val sym = VarSymbol(id) + val sym = VarSymbol(id, erasedType = N) sym.decl = S(TyParam(FldFlags.empty, N, sym)) Param(FldFlags.empty, sym, N, Modulefulness.none) :: Nil case t => diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala index e987655c66..63cb715879 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala @@ -30,7 +30,7 @@ class Importer: val nme = file.baseName val id = alias.getOrElse(new syntax.Tree.Ident(nme)) // TODO loc - lazy val sym = VarSymbol(id) + lazy val sym = VarSymbol(id, erasedType = N) if path.startsWith(".") || path.startsWith("/") then // leave alone imports like "fs" log(s"importing $file") @@ -59,7 +59,7 @@ class Importer: case Some(nme -> imsym) => imsym case None => lastWords(s"File $file does not define a symbol named $nme") val sym: VarSymbol | BlockMemberSymbol = alias.fold(importedSym): alias => - VarSymbol(alias) + VarSymbol(alias, erasedType = N) val jsFile = file.up / io.RelPath(file.baseName + ".mjs") Import(sym, jsFile.toString, jsFile) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala index 0c3321fa04..d537930483 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala @@ -53,7 +53,7 @@ object Pattern: // TODO: The above edge case would fail the following assertion. assert(symbols.size <= 1) // If no symbol had been created before, create a new symbol now. - val symbol = symbols.headOption.getOrElse(VarSymbol(Ident(name))) + val symbol = symbols.headOption.getOrElse(VarSymbol(Ident(name), erasedType = N)) aliases.foreach: alias => // For guarded patterns (`p where t`), the variables in `p` have to be // allocated before `t` is elaborated. In that case, we don't need to diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index b5cee7f206..26bac61396 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -10,6 +10,7 @@ import hkmc2.utils.* import Elaborator.State import Tree.Ident +import hkmc2.codegen.{ErasedType, HasErasedType, HasManyMutableErasedType, HasOnceMutableErasedType, PrimitiveType, erasedType} import hkmc2.utils.SymbolSubst @@ -224,12 +225,13 @@ class SplitSymbol(val body: Split, name: Str = "split")(using State) extends Loc def toLoc = body.toLoc override def prefix: Str = "split:" -sealed abstract class LocalVarSymbol(name: Str)(using State) extends FlowSymbol(name) with LocalSymbol: +sealed abstract class LocalVarSymbol(name: Str)(using State) extends FlowSymbol(name) with LocalSymbol with HasErasedType: self: LocalSymbol => // * using `with LocalSymbol` in the `extends` clause makes Scala think there's a bad override var decl: Opt[Declaration] = N def subst(using s: SymbolSubst): LocalVarSymbol -class TempSymbol(val trm: Opt[Term], dbgNme: Str = "tmp")(using State) extends LocalVarSymbol(dbgNme): +class TempSymbol(val trm: Opt[Term], override val erasedType: Opt[ErasedType], dbgNme: Str = "tmp")(using State) + extends LocalVarSymbol(dbgNme): // val nameHints: MutSet[Str] = MutSet.empty // * May be useful later? override def toLoc: Option[Loc] = trm.flatMap(_.toLoc) override def prefix: Str = "tmp:" @@ -245,7 +247,10 @@ class InstSymbol(val origin: Symbol)(using State) extends LocalSymbol: def subst(using sub: SymbolSubst): InstSymbol = sub.mapInstSym(this) -class VarSymbol(val id: Ident)(using State) extends LocalVarSymbol(id.name) with NamedSymbol: +class VarSymbol(val id: Ident, override var erasedType: Opt[ErasedType])(using State) + extends LocalVarSymbol(id.name) + with HasManyMutableErasedType + with NamedSymbol: val name: Str = id.name var sourceAliases: Ls[Str] = Nil override def toLoc: Opt[Loc] = id.toLoc @@ -253,8 +258,8 @@ class VarSymbol(val id: Ident)(using State) extends LocalVarSymbol(id.name) with override def subst(using s: SymbolSubst): VarSymbol = s.mapVarSym(this) class BuiltinSymbol - (val nme: Str, val binary: Bool, val unary: Bool, val nullary: Bool, val functionLike: Bool)(using State) - extends Symbol: + (val nme: Str, val binary: Bool, val unary: Bool, val nullary: Bool, val functionLike: Bool, override val erasedType: Opt[ErasedType])(using State) + extends Symbol with HasErasedType: def toLoc: Option[Loc] = N override def prefix: Str = "builtin:" @@ -280,6 +285,35 @@ class BuiltinSymbol case _ => Bot semantics.flow.Producer.Typ(typ) + /** The result [[`ErasedType`]] of applying this builtin operator to operands of the given erased types. + * + * Returns `N` if the operator is not recognized or the arguments to the operator is not sufficient to determine the + * result type. + */ + def resultErasedType(args: Ls[Opt[ErasedType]]): Opt[ErasedType] = + import ErasedType.Primitive + def isStr(t: Opt[ErasedType]) = t.contains(Primitive(PrimitiveType.Str)) + def isInt(t: Opt[ErasedType]) = t.contains(Primitive(PrimitiveType.Int)) + def isNum(t: Opt[ErasedType]) = isInt(t) || t.contains(Primitive(PrimitiveType.Num)) + nme match + case "==" | "!=" | "<" | "<=" | ">" | ">=" | "===" | "!==" | "&&" | "||" | "!" => + S(Primitive(PrimitiveType.Bool)) + case "typeof" => S(Primitive(PrimitiveType.Str)) + // * `+` is overloaded (numeric add / string concat): only commit when the operands + // * decide it, otherwise leave it unknown (an unknown operand could be a `Str`). + case "+" => + if args.exists(isStr) then S(Primitive(PrimitiveType.Str)) + else if args.forall(isInt) then S(Primitive(PrimitiveType.Int)) + else if args.forall(isNum) then S(Primitive(PrimitiveType.Num)) + else N + // * The remaining arithmetic operators are numeric-only, so the result is always a + // * number; it is an `Int` only when every operand is known to be an `Int`. + case "-" | "*" | "%" => + if args.forall(isInt) then S(Primitive(PrimitiveType.Int)) else S(Primitive(PrimitiveType.Num)) + case "/" => S(Primitive(PrimitiveType.Num)) + case "~" => S(Primitive(PrimitiveType.Int)) + case _ => N + /** This is the outside-facing symbol associated to a possibly-overloaded * definition living in a block – e.g., a module or class. @@ -336,9 +370,10 @@ sealed abstract class MemberSymbol(using State) extends Symbol: def subst(using SymbolSubst): MemberSymbol -class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident)(using State) +class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident, override var erasedType: Opt[ErasedType])(using State) extends MemberSymbol with DefinitionSymbol[TermDefinition] + with HasOnceMutableErasedType with NamedSymbol: var sourceAliases: Ls[Str] = Nil def nme: Str = id.name @@ -362,8 +397,8 @@ class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.I defn.forall(_.mayRaiseEffects) object TermSymbol: - def fromFunBms(b: BlockMemberSymbol, owner: Opt[InnerSymbol])(using State) = - TermSymbol(syntax.Fun, owner, Tree.Ident(b.nme)) + def fromFunBms(b: BlockMemberSymbol, owner: Opt[InnerSymbol], erasedType: Opt[ErasedType])(using State) = + TermSymbol(syntax.Fun, owner, Tree.Ident(b.nme), erasedType) /** Represents the companion constructor function of parameterized classes, @@ -373,7 +408,7 @@ class ClassCtorSymbol( override val k: syntax.Fun.type, override val owner: Opt[InnerSymbol], val associatedCls: ClassSymbol, -)(using State) extends TermSymbol(k, owner, associatedCls.id): +)(using State) extends TermSymbol(k, owner, associatedCls.id, N): override def subst(using sub: SymbolSubst): ClassCtorSymbol = sub.mapClassCtorSym(this) override def mayRaiseEffects(using Config) = super.mayRaiseEffects || config.checkInstantiateEffect @@ -388,7 +423,8 @@ case class Extr(isTop: Bool)(using State) extends CtorSymbol: def toLoc: Option[Loc] = N override def toString: Str = nme -sealed abstract case class LitSymbol(lit: Literal)(using State) extends CtorSymbol: +sealed abstract case class LitSymbol(lit: Literal)(using State) extends CtorSymbol, HasErasedType: + override val erasedType: Opt[ErasedType] = S(lit.erasedType) def nme: Str = lit.idStr def toLoc: Option[Loc] = lit.toLoc override def prefix: Str = "lit:" @@ -413,7 +449,7 @@ case class ErrorSymbol(val nme: Str, tree: Tree)(using State) extends MemberSymb override def subst(using sub: SymbolSubst): ErrorSymbol = sub.mapErrorSym(this) override def prefix: Str = "error:" -sealed trait ClassLikeSymbol extends IdentifiedSymbol: +sealed trait ClassLikeSymbol extends IdentifiedSymbol, HasErasedType: self: MemberSymbol & DefinitionSymbol[? <: ClassDef | ModuleOrObjectDef] => val tree: Tree.TypeDef def subst(using sub: SymbolSubst): ClassLikeSymbol @@ -476,7 +512,8 @@ sealed trait InnerSymbol(using State) extends Symbol: // ensure that any implementation of InnerSymbol is also a DefinitionSymbol. self: DefinitionSymbol[? <: ClassLikeDef] => val privatesScope: Scope = Scope.empty(Scope.Cfg.default) // * Scope for private members of this symbol - val thisProxy: TempSymbol = TempSymbol(N, s"this$$$nme") + // TODO(Derppening): Can we meaningfully infer the erased type of `this` from the definition? + val thisProxy: TempSymbol = TempSymbol(N, erasedType = N, s"this$$$nme") def subst(using SymbolSubst): InnerSymbol def asDefnSym: DefinitionSymbol[? <: ClassLikeDef] & InnerSymbol = this match case d: DefinitionSymbol[? <: ClassLikeDef] => d @@ -492,6 +529,8 @@ class ClassSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using State) with InnerSymbol with NamedSymbol: + override val erasedType: Opt[ErasedType] = S(ErasedType.AnyRef(rsc = false, this)) + def name: Str = nme def nme = id.name def toLoc: Option[Loc] = id.toLoc // TODO track source tree of classe here @@ -508,6 +547,9 @@ class ModuleOrObjectSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using Sta with DefinitionSymbol[ModuleOrObjectDef] with InnerSymbol with NamedSymbol: + + override val erasedType: Opt[ErasedType] = S(ErasedType.AnyRef(rsc = false, this)) + def name: Str = nme def nme = id.name def toLoc: Option[Loc] = id.toLoc // TODO track source tree of module here @@ -519,7 +561,11 @@ class ModuleOrObjectSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using Sta class TypeAliasSymbol(val id: Tree.Ident)(using State) extends MemberSymbol - with DefinitionSymbol[TypeDef]: + with DefinitionSymbol[TypeDef] + with HasErasedType: + + override val erasedType: Opt[ErasedType] = irClsLikeDefn.flatMap(_.sym.asClsOrMod).map(ErasedType.AnyRef(rsc = false, _)) + def nme = id.name def toLoc: Option[Loc] = id.toLoc // TODO track source tree of type alias here override def prefix: Str = "type:" diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala index 8f8fa55b79..9a8312e650 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala @@ -605,6 +605,7 @@ extension (self: Blk) case class ShowCfg( + showErasedTypes: Bool, showExpansionMappings: Bool, showFlowSymbols: Bool, debug: Bool, @@ -616,6 +617,7 @@ end ShowCfg object ShowCfg: // * For use when displaying things for internal use (not for end users) val internal = ShowCfg( + showErasedTypes = true, showFlowSymbols = true, showExpansionMappings = false, debug = false, diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 119b474b05..eed9fa09e1 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -400,7 +400,8 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C // the Label body into the rest. Wrap with an exit label and temp variable so every path stores its // result, breaks to exitLabel, then the original cont runs once. val exitLabel = new LabelSymbol(N, sym.nme + "$x") - val tmp = new TempSymbol(N) + // TODO(Derppening): Change this to `r.erasedType` when `Result.erasedType` is implemented + val tmp = new TempSymbol(N, erasedType = N) LoweringCtx.loweringCtx.collectScopedSym(tmp) val exitCont: Result => Block = r => Assign(tmp, r, Break(exitLabel)) val bodyBlock = lowerSplit(sym.body, exitCont) @@ -455,7 +456,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C // 3. The term is a `while` and the result is used. lazy val l = usesResTmp = true - val res = new TempSymbol(t) + val res = new TempSymbol(t, erasedType = N) outerCtx.collectScopedSym(res) res // The symbol for the loop label if the term is a `while`. @@ -464,7 +465,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C val res = new BlockMemberSymbol("while", Nil, false) outerCtx.collectScopedSym(res) res - lazy val tSym = TermSymbol.fromFunBms(f, N) + lazy val tSym = TermSymbol.fromFunBms(f, N, erasedType = N) val normalized = tl.scoped("ucs:normalize"): normalize(inputSplit)(using VarSet()) tl.scoped("ucs:normalized"): @@ -473,7 +474,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C form match case IfLikeForm.ReturningIf => if (k is Ret) || (k is Thrw) then k(r) else Assign(l, r, End()) case IfLikeForm.ImperativeIf => Assign.discard(r, End()) - case IfLikeForm.While => Assign(State.noSymbol, r, loopCont) + case IfLikeForm.While => Assign(NoSymbol, r, loopCont) // NOTE: `shouldRewriteWhile` is not the same as `config.rewriteWhileLoops` // as shouldRewriteWhile is always true when effect handler lowering is on lazy val loopCont = if config.shouldRewriteWhile @@ -505,8 +506,8 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State, C // NOTE: `shouldRewriteWhile` is not the same as `config.rewriteWhileLoops` // as shouldRewriteWhile is always true when effect handler lowering is on if config.shouldRewriteWhile then - val loopResult = TempSymbol(N) - val isReturned = TempSymbol(N) + val loopResult = TempSymbol(N, erasedType = N) + val isReturned = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool))) outerCtx.collectScopedSym(loopResult) outerCtx.collectScopedSym(isReturned) val loopEnd: Path = diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/TermSynthesizer.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/TermSynthesizer.scala index 3018d79677..c97e5ebb71 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/TermSynthesizer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/TermSynthesizer.scala @@ -93,11 +93,11 @@ trait TermSynthesizer(using State): app(stringLeave, tup(fld(t), fld(int(n))), label) protected final def tempLet(dbgName: Str, term: Term)(inner: TempSymbol => Split): Split = - val s = TempSymbol(N, dbgName) + val s = TempSymbol(N, erasedType = N, dbgName) Split.Let(s, term, inner(s)) protected final def plainTest(cond: Term, dbgName: Str = "cond")(inner: => Split): Split = - val s = TempSymbol(N, dbgName) + val s = TempSymbol(N, erasedType = N, dbgName) Split.Let(s, cond, Branch(s.safeRef, inner) ~: Split.End) protected final def makeBindings(fields: Ls[RcdField | RcdSpread]): Term = diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala index 9e767fd48b..b110b4185c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala @@ -136,7 +136,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter // by the set of labels (orders are not important). val labels = patterns.map(_.label) multiMatchers.get(labels).getOrElse: - val f = TempSymbol(N, makeMultiMatcherName(patterns)) + val f = TempSymbol(N, erasedType = N, makeMultiMatcherName(patterns)) multiMatchers += (labels -> f) buildQueue enqueue (f -> patterns) f // Return the symbol of the built function. @@ -152,7 +152,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val expandedPatterns = patterns.map(p => (p.label, p.expand(Set.empty))) val heads = expandedPatterns.flatMap((_, p) => p.heads).toList // This is the parameter of the current multi-matcher. - val scrutinee = VarSymbol(Ident("input")) + val scrutinee = VarSymbol(Ident("input"), erasedType = N) // Assemble branches for constructors and literals. val branches = heads.map: head => val specialized = expandedPatterns.specializeSet(S(head)) @@ -164,7 +164,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val empty: (Ls[(LocalVarSymbol, Opt[Loc])], Map[Ident | Int, LocalVarSymbol]) = (Nil, Map.empty) val (arguments, fields) = params.params.foldLeft(empty): (acc, param) => val (argAcc, fieldAcc) = acc - val fieldSymbol = TempSymbol(N, param.sym.nme) + val fieldSymbol = TempSymbol(N, erasedType = N, param.sym.nme) (argAcc :+ (fieldSymbol, param.toLoc), fieldAcc + ((param.sym.id: Ident | Int) -> fieldSymbol)) (S(arguments), fields) case _: (syntax.Literal | ModuleOrObjectSymbol) => empty @@ -195,7 +195,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val subPatternsByField = Map.from(fields.map: field => field -> patterns.flatMap((_, p) => p.collectSubPatterns(field))) val subScrutinees = Map.from(subPatternsByField.map: (field, subPatterns) => - field -> MatcherResult(VarSymbol(field.asIdent), subPatterns.map(_.label))) + field -> MatcherResult(VarSymbol(field.asIdent, erasedType = N), subPatterns.map(_.label))) // Let bindings that bind the sub-scrutinee to the result of each matcher. val bindings = subPatternsByField.iterator.flatMap: (field, subPatterns) => val subScrutinee = subScrutinees(field) @@ -211,7 +211,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case N => // Check the presence of the field, and call the matcher if it exists. val fieldIdent: Ident = field.asIdent - val fieldSymbol = TempSymbol(N, fieldIdent.name) + val fieldSymbol = TempSymbol(N, erasedType = N, fieldIdent.name) val fieldTest = FlatPattern.Record((fieldIdent -> fieldSymbol) :: Nil) val consequent = Split.Else(makeResult(fieldSymbol)) val branch = Branch(scrutinee.safeRef, fieldTest, consequent) @@ -223,7 +223,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val z = (Nil: Ls[Statement], Nil: Ls[(Label, Term)]) val (tests, resultTerms) = patterns.iterator.foldLeft(z): case ((stmts, results), (label, pattern)) => - val symbol = TempSymbol(N, label.asFieldName + "$") + val symbol = TempSymbol(N, erasedType = N, label.asFieldName + "$") val makeSplit = completePattern(pattern, scrutinee, subScrutinees, Nil) val split = makeSplit( // There is no topmost transform here, so we emit the direct success @@ -278,7 +278,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case N => Split.Else: makeMatchSuccess(output, nullifyEmptyBindings(bindings)) case S(transform) => - val resultSymbol = TempSymbol(N, "transformResult") + val resultSymbol = TempSymbol(N, erasedType = N, "transformResult") val bindingsTerm = nullifyEmptyBindings(bindings) val transformTerm = app(transform.safeRef, tup(fld(bindingsTerm)), "the transform's result") Split.Let(resultSymbol, transformTerm, Split.Else( @@ -297,7 +297,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case ((field, pattern), makeInnerSplit) => val label = pattern.label val target = subScrutinees(field).select(label) - val resultSymbol = TempSymbol(N, s"result$label$$") + val resultSymbol = TempSymbol(N, erasedType = N, s"result$label$$") (makeConsequent, alternative) => Split.Let(resultSymbol, target, Branch(resultSymbol.safeRef, @@ -338,10 +338,10 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case ((field, pattern), makeInnerSplit) => val label = pattern.label val target = subScrutinees(field).select(label) - val resultSymbol = TempSymbol(N, s"result$label$$") - val outputSymbol = TempSymbol(N, s"output$label$$") + val resultSymbol = TempSymbol(N, erasedType = N, s"result$label$$") + val outputSymbol = TempSymbol(N, erasedType = N, s"output$label$$") val fieldAliases = pattern.aliases - val fieldBindingsSymbol = TempSymbol(N, "fieldBindings") + val fieldBindingsSymbol = TempSymbol(N, erasedType = N, "fieldBindings") val fieldBindingsTerm = makeBindings(fieldAliases.map: alias => RcdField(str(alias.name), outputSymbol.safeRef)) val fieldOutput = @@ -349,7 +349,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter else subScrutinees(field).input (outputFields: Ls[(Ident, Term)], bindingsSymbols: Ls[TempSymbol]) => ((makeConsequent, alternative) => - val bindingsSymbol = TempSymbol(N, "bindings") + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") val accumulatedBindings = val withFieldAliases = if fieldAliases.isEmpty then bindingsSymbols @@ -373,7 +373,7 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter case ((field, pattern), makeInnerSplit) => val label = pattern.label val target = subScrutinees(field).select(label) - val resultSymbol = TempSymbol(N, s"result$label$$") + val resultSymbol = TempSymbol(N, erasedType = N, s"result$label$$") (makeConsequent, alternative) => Split.Let(resultSymbol, target, Branch(resultSymbol.safeRef, @@ -412,18 +412,18 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val label = pattern.label val target = subScrutinees(field).select(label) // This is the symbol for `MatchSuccess`. - val resultSymbol = TempSymbol(N, s"result$label$$") + val resultSymbol = TempSymbol(N, erasedType = N, s"result$label$$") // This is the symbol for the output of the pattern. - val outputSymbol = TempSymbol(N, s"output$label$$") + val outputSymbol = TempSymbol(N, erasedType = N, s"output$label$$") val outputField = RcdField(str(field.name), outputSymbol.safeRef) // This is the bindings of the current field. val fieldAliases = pattern.aliases - val fieldBindingsSymbol = TempSymbol(N, "fieldBindings") + val fieldBindingsSymbol = TempSymbol(N, erasedType = N, "fieldBindings") val fieldBindingsTerm = makeBindings(fieldAliases.map: alias => RcdField(str(alias.name), outputSymbol.safeRef)) (outputFields: Ls[RcdField], bindingsSymbols: Ls[TempSymbol]) => ((makeConsequent, alternative) => - val bindingsSymbol = TempSymbol(N, "bindings") + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") val accumulatedBindings = val withFieldAliases = if fieldAliases.isEmpty then bindingsSymbols @@ -480,9 +480,9 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter (makeConsequent, alternative) => // Here, we need to make a tuple of all output values of patterns. // Then we merge the bindings of all patterns. - val outputSymbol = TempSymbol(N, "combinedOutput") + val outputSymbol = TempSymbol(N, erasedType = N, "combinedOutput") val outputTerm = tup(allOutputs.reverseIterator.map(_.use |> fld).toSeq*) - val bindingsSymbol = TempSymbol(N, "combinedBindings") + val bindingsSymbol = TempSymbol(N, erasedType = N, "combinedBindings") // I think the bindings do not need to be reversed. val bindingsTerm = makeBindings(allBindings.map: binding => RcdSpread(binding.use)) @@ -519,9 +519,9 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter else // The symbol representing the transform function, which should be // declared at the outermost level. - val transformSymbol = TempSymbol(N, "transform") + val transformSymbol = TempSymbol(N, erasedType = N, "transform") // The transform function takes a single record as the argument. - val bindingsSymbol = VarSymbol(Ident("args")) + val bindingsSymbol = VarSymbol(Ident("args"), erasedType = N) val params = paramList(param(bindingsSymbol)) // Because we pass the extracted values using recoreds. We need to bind // each property to its corresponding variable which is accessible from @@ -542,10 +542,10 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter val transformTerm = app(transformSymbol.safeRef, tup(fld(bindings.use)), "the transform's result") // Bind the transformation result to a new output symbol. - val resultSymbol = TempSymbol(N, "transformResult") + val resultSymbol = TempSymbol(N, erasedType = N, "transformResult") // Don't forget that current pattern may also have aliases which are // available in some outer transform patterns. - val currentBindingsSymbol = TempSymbol(N, "bindings") + val currentBindingsSymbol = TempSymbol(N, erasedType = N, "bindings") val currentBindings = makeBindings(aliases.map: alias => RcdField(str(alias.name), resultSymbol.safeRef)) Split.Let(resultSymbol, transformTerm, diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/FixedPointCompiler.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/FixedPointCompiler.scala index 1776a410f2..0739e667b1 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/FixedPointCompiler.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/FixedPointCompiler.scala @@ -11,6 +11,7 @@ import ucs.{FlatPattern, TermSynthesizer, warn, safeRef} import semantics.Pattern as SP import Pattern.*, Context.* import Compiler.ResultMode +import codegen.{ErasedType, PrimitiveType} object FixedPointCompiler: /** One alternative of the context pattern that descends into the hole. @@ -499,19 +500,19 @@ class FixedPointCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynt postCompiler.buildMatcher(postPattern, ResultMode.Full) (stepMatcher, stepImpls, postMatcherOpt) - val inputSymbol = VarSymbol(Ident("input")) - val phaseSymbol = TempSymbol(N, "phase") - val focusSymbol = TempSymbol(N, "focus") + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) + val phaseSymbol = TempSymbol(N, erasedType = N, "phase") + val focusSymbol = TempSymbol(N, erasedType = N, "focus") // The phase whose step failed, selecting the link whose alternatives // process the final term. Since the loop only ever exits through a step // failure, it is always set before `result` runs. - val failPhaseSymbol = TempSymbol(N, "failPhase") + val failPhaseSymbol = TempSymbol(N, erasedType = N, "failPhase") val size = links.size val loop = matchers.iterator.zipWithIndex.foldRight(Split.End: Split): case (((stepMatcher, _, _), index), rest) => - val resultSym = TempSymbol(N, s"step$index$$Result") - val outputSym = TempSymbol(N, "stepOutput") + val resultSym = TempSymbol(N, erasedType = N, s"step$index$$Result") + val outputSym = TempSymbol(N, erasedType = N, "stepOutput") Branch(phaseSymbol.safeRef, intPattern(index), Split.Let(resultSym, callMatcher(stepMatcher, focusSymbol.safeRef, "step result"), Branch(resultSym.safeRef, matchSuccessPattern(S(outputSym :: Nil)), @@ -577,15 +578,15 @@ class FixedPointCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynt log(s"Classes: ${classes.map(cls => s"${cls.symbol.nme}(${cls.alts})").mkString(", ")}") // ---- Machine state ---- - val inputSymbol = VarSymbol(Ident("input")) - val modeSymbol = TempSymbol(N, "mode") - val focusSymbol = TempSymbol(N, "focus") - val stackSymbol = TempSymbol(N, "stack") - val resultSymbol = TempSymbol(N, "finalResult") + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) + val modeSymbol = TempSymbol(N, erasedType = N, "mode") + val focusSymbol = TempSymbol(N, erasedType = N, "focus") + val stackSymbol = TempSymbol(N, erasedType = N, "stack") + val resultSymbol = TempSymbol(N, erasedType = N, "finalResult") // Whether at least one contraction has fired. Only the one-or-more shape // (`P as (S | _)`) tracks it: it requires the first step to succeed, so a // run with zero contractions is a match failure. - val progressedSymbol = TempSymbol(N, "progressed") + val progressedSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "progressed") def bool(value: Bool): Term = Term.Lit(BoolLit(value)) def markProgress: Ls[Statement] = @@ -620,14 +621,14 @@ class FixedPointCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynt sides match case Nil => success case (index, side) :: rest => - val okSymbol = TempSymbol(N, "sideOk") + val okSymbol = TempSymbol(N, erasedType = N, "sideOk") Split.Let(okSymbol, callMatcher(sideMatchers(side), child(index), "side condition"), Branch(okSymbol.safeRef, sideChecks(rest, child, success, failure)) ~: failure()) /** Call the redex matcher on `scrutinee` and branch on its result. */ def matchRedex(scrutinee: Term, onSuccess: TempSymbol => Split, onFailure: Split): Split = - val resultSym = TempSymbol(N, "redexResult") - val outputSym = TempSymbol(N, "contractum") + val resultSym = TempSymbol(N, erasedType = N, "redexResult") + val outputSym = TempSymbol(N, erasedType = N, "contractum") Split.Let(resultSym, callMatcher(redexMatcher, scrutinee, "redex match"), Branch(resultSym.safeRef, matchSuccessPattern(S(outputSym :: Nil)), onSuccess(outputSym)) ~: onFailure) @@ -646,7 +647,7 @@ class FixedPointCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynt sideChecks(alt.sides, index => children(index).safeRef, push, failure) val findSplit = val classChain = classes.foldRight(goUp()): (cls, rest) => - val children = List.tabulate(cls.paramCount)(index => TempSymbol(N, s"scrut$index")) + val children = List.tabulate(cls.paramCount)(index => TempSymbol(N, erasedType = N, s"scrut$index")) Branch(focusSymbol.safeRef, classPattern(cls, children), findDescend(cls, children)) ~: rest matchRedex(focusSymbol.safeRef, // A contraction: refocus on the contractum and keep searching. This @@ -666,7 +667,7 @@ class FixedPointCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynt if index == hole then focusSymbol.safeRef else children(index).safeRef def inert(index: Int): Term = if index == hole then bool(true) else inerts(index).safeRef - val rebuiltSymbol = TempSymbol(N, "rebuilt") + val rebuiltSymbol = TempSymbol(N, erasedType = N, "rebuilt") val pop = perform( setStmt(stackSymbol, tailSym.safeRef), setStmt(focusSymbol, rebuiltSymbol.safeRef)) @@ -685,7 +686,7 @@ class FixedPointCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynt // Plugging preserves identity: when no contraction happened below the // frame, the focus is still the very child we descended into, and the // frame's node can be reused instead of allocating a rebuilt copy. - val unchangedSymbol = TempSymbol(N, "unchanged") + val unchangedSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "unchanged") Split.Let(unchangedSymbol, refEq(focusSymbol.safeRef, children(hole).safeRef), Split.Let(rebuiltSymbol, Term.SynthIf( @@ -698,14 +699,14 @@ class FixedPointCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynt setStmt(modeSymbol, int(ModeFind)) :: markProgress)*), altChain))) def upClass(cls: ClassInfo): Split = - val children = List.tabulate(cls.paramCount)(index => TempSymbol(N, s"frameChild$index")) - val inerts = List.tabulate(cls.paramCount)(index => TempSymbol(N, s"frameInert$index")) - val nodeSym = TempSymbol(N, "frameNode") - val tailSym = TempSymbol(N, "frameTail") + val children = List.tabulate(cls.paramCount)(index => TempSymbol(N, erasedType = N, s"frameChild$index")) + val inerts = List.tabulate(cls.paramCount)(index => TempSymbol(N, erasedType = N, s"frameInert$index")) + val nodeSym = TempSymbol(N, erasedType = N, "frameNode") + val tailSym = TempSymbol(N, erasedType = N, "frameTail") val core = cls.alts.map(_.holeIndex).distinct match case only :: Nil => upHole(cls, only, children, inerts, nodeSym, tailSym) case multiple => - val holeSymbol = TempSymbol(N, "frameHole") + val holeSymbol = TempSymbol(N, erasedType = N, "frameHole") Split.Let(holeSymbol, sel(stackSymbol.safeRef, "h"), multiple.init.foldRight(upHole(cls, multiple.last, children, inerts, nodeSym, tailSym)): (hole, rest) => Branch(holeSymbol.safeRef, intPattern(hole), @@ -727,7 +728,7 @@ class FixedPointCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynt case Nil => Split.End // No frames are ever pushed. case only :: Nil => upClass(only) case multiple => - val tagSymbol = TempSymbol(N, "frameTag") + val tagSymbol = TempSymbol(N, erasedType = N, "frameTag") Split.Let(tagSymbol, sel(stackSymbol.safeRef, "tag"), multiple.init.foldRight(upClass(multiple.last)): (cls, rest) => Branch(tagSymbol.safeRef, intPattern(cls.index), upClass(cls)) ~: rest) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala index 0e5465894f..26fb7aac7d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ups/SplitCompiler.scala @@ -7,6 +7,7 @@ import Message.MessageContext import ucs.{TermSynthesizer, FlatPattern, error, warn, safeRef}, ucs.extractors.* import syntax.{Fun, Keyword, Tree}, Tree.{Ident, StrLit}, Keyword.{`as`, `=>`} import collection.mutable.{Buffer, HashMap}, collection.immutable.SeqMap +import codegen.{ErasedType, PrimitiveType} import Elaborator.{Ctx, State, ctx}, utils.TL import semantics.Pattern as SP // "SP" is short for "semantic patterns" import Term.Ref @@ -28,14 +29,14 @@ object SplitCompiler: def getSubScrutinee(cs: ClassSymbol | PatternSymbol)(i: Int)(using State): SymbolScrut[?] = val scrutinees = subScrutinees.getOrElseUpdate(cs, Buffer.empty) while scrutinees.size <= i do - scrutinees += TempSymbol(N, s"arg$$${cs.nme}$$${scrutinees.size}$$").toScrut + scrutinees += TempSymbol(N, erasedType = N, s"arg$$${cs.nme}$$${scrutinees.size}$$").toScrut scrutinees(i) def getTupleLeadSubScrutinee(index: Int)(using State): SymbolScrut[?] = - tupleLead.getOrElseUpdate(index, TempSymbol(N, s"element$index$$").toScrut) + tupleLead.getOrElseUpdate(index, TempSymbol(N, erasedType = N, s"element$index$$").toScrut) def getTupleLastSubScrutinee(index: Int)(using State): SymbolScrut[?] = - tupleLast.getOrElseUpdate(index, TempSymbol(N, s"lastElement$index$$").toScrut) + tupleLast.getOrElseUpdate(index, TempSymbol(N, erasedType = N, s"lastElement$index$$").toScrut) def getFieldScrutinee(fieldName: Ident)(using State): SymbolScrut[?] = - fields.getOrElseUpdate(fieldName, TempSymbol(N, s"field_${fieldName.name}$$").toScrut) + fields.getOrElseUpdate(fieldName, TempSymbol(N, erasedType = N, s"field_${fieldName.name}$$").toScrut) object Scrut: def from(ref: Term.Ref): RefScrut = RefScrut(() => ref) @@ -109,7 +110,7 @@ object SplitCompiler: private var _hasBeenUsed = false private lazy val symbol = _hasBeenUsed = true - TempSymbol(N, nameHint.getOrElse("output")) + TempSymbol(N, erasedType = N, nameHint.getOrElse("output")) def apply(): Ref = symbol.safeRef def toList: Ls[TempSymbol] = if _hasBeenUsed then symbol :: Nil else Nil def toLet(term: => Term, tail: Split): Split = @@ -312,7 +313,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz def makePatternBindings(using State): Ls[(TempSymbol, SP)] = patterns.iterator.zipWithIndex.map: - case (pattern, index) => new TempSymbol(N, s"patternArgument${index}$$") -> pattern + case (pattern, index) => new TempSymbol(N, erasedType = N, s"patternArgument${index}$$") -> pattern .toList /** @@ -580,11 +581,11 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz // Then, we can call the `unapply` method. val unapplyArguments = patternBindings.map(_._1.safeRef |> fld) :+ fld(scrutinee()) val unapplyCall = app(sel(patternTerm, "unapply").resolve, tup(unapplyArguments*), s"result of unapply") - val unapplyResult = TempSymbol(N, "unapplyResult") + val unapplyResult = TempSymbol(N, erasedType = N, "unapplyResult") val wrapUnapply = Split.Let(unapplyResult, unapplyCall, _) // Then, we need to destruct the produced `MatchSuccess`. - val outputSymbol = TempSymbol(N, "output").toScrut // TODO: We can use `LazyScrut` for this, but the match result pattern's parameters requires a symbol. - val bindingsSymbol = TempSymbol(N, "bindings") // TODO: This can be automatically removed when no transformation is used. + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut // TODO: We can use `LazyScrut` for this, but the match result pattern's parameters requires a symbol. + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") // TODO: This can be automatically removed when no transformation is used. val wrapDestruction = (consequent: Split) => val pattern = matchSuccessPattern(S(outputSymbol.symbol :: bindingsSymbol :: Nil)) Branch(unapplyResult.safeRef, pattern, consequent) ~: alternative @@ -632,8 +633,8 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val unapplyTerm = sel(parameterTerm, "unapply").resolve val unapplyCall = app(unapplyTerm, tup(fld(scrutinee())), s"result of unapply") tempLet(s"matchSuccess_${parameterSymbol.name}", unapplyCall): matchSuccessSymbol => - val outputSymbol = TempSymbol(N, "output").toScrut - val bindingsSymbol = TempSymbol(N, "bindings").toScrut + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings").toScrut val pattern = matchSuccessPattern(S(Ls(outputSymbol.symbol, bindingsSymbol.symbol))) val consequent = makeConsequent(outputSymbol, SeqMap.empty) Branch(matchSuccessSymbol.safeRef, pattern, consequent) ~: alternative @@ -730,7 +731,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val trailingSize = trailing.size val (trailSubScrutinees, makeConsequent0) = trailing.folded(makeConsequent): index => scrutinee.getTupleLastSubScrutinee(trailingSize - index) - val spreadSubScrutinee = TempSymbol(N, "middleElements") + val spreadSubScrutinee = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "middleElements") val makeConsequent1: MakeConsequent = (outerOutput, outerBindings) => makeMatchSplit(spreadSubScrutinee.toScrut, spread, false)( (spreadOutput, spreadBindings) => makeConsequent0( @@ -793,7 +794,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val params = parameters.map: case (_, parameterSymbol) => Param(FldFlags.empty, parameterSymbol, N, Modulefulness.none) - val lambdaSymbol = new TempSymbol(N, "transform") + val lambdaSymbol = new TempSymbol(N, erasedType = N, "transform") // Next, we need to elaborate the pattern into a split. Note that // `makeMatchSplit` returns a function that takes a split as the // consequence. `makeMatchSplit` also takes a list of symbols so that @@ -811,7 +812,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz (_output, bindings) => val arguments = symbols.iterator.map(bindings).map(_() |> fld).toSeq val resultTerm = app(lambdaSymbol.safeRef, tup(arguments*), "the transform's result") - val resultSymbol = TempSymbol(N, "transformResult") + val resultSymbol = TempSymbol(N, erasedType = N, "transformResult") Split.Let(resultSymbol, resultTerm, makeConsequent(resultSymbol.toScrut, SeqMap.empty)), alternative)) case Annotated(pattern, annotations) => @@ -837,7 +838,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case Guarded(pattern, guard) => (makeConsequent, alternative) => makeMatchSplit(scrutinee, pattern, true)( (output, bindings) => - val guardSymbol = TempSymbol(N, "guardResult") + val guardSymbol = TempSymbol(N, erasedType = N, "guardResult") val branch = Branch(guardSymbol.ref(), makeConsequent(output, bindings)) val innermost = Split.Let(guardSymbol, guard, branch ~: Split.End) // The creation of bindings here is repeated with the creation of @@ -858,8 +859,8 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz matchParametersWithArguments(defn, arguments) warnOnDiscardedExtractionOutputs(patternSymbol, extractionMatches) if shouldReject then RejectPrefixSplit else (makeConsequent, alternative) => - val outputSymbol = TempSymbol(N, "output").toScrut - val remainingSymbol = TempSymbol(N, "remaining").toScrut + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut + val remainingSymbol = TempSymbol(N, erasedType = N, "remaining").toScrut val (extractionArguments, makeConsequentForArguments) = extractionMatches.fold((N, makeConsequent)): _.folded(makeConsequent)(scrutinee.getSubScrutinee(patternSymbol)).mapFirst(S(_)) @@ -870,8 +871,8 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val args = tup(patternBindings.map(_._1.safeRef) :+ scrutinee()) app(sel(patternTerm, method), args, s"result of $method") val split = tempLet("unapplyResult", unapplyCall): matchSuccessSymbol => - val outputPairSymbol = TempSymbol(N, "outputPair") - val bindingsSymbol = TempSymbol(N, "bindings") + val outputPairSymbol = TempSymbol(N, erasedType = N, "outputPair") + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") Branch( matchSuccessSymbol.safeRef, matchSuccessPattern(S(outputPairSymbol :: bindingsSymbol :: Nil)), @@ -896,12 +897,12 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val unapplyCall = app(unapplyTerm, tup(fld(scrutinee())), s"result of unapply") tempLet(s"matchSuccess_${parameterSymbol.name}", unapplyCall): matchSuccessSymbol => // Destruct the `MatchSuccess` produced by the `unapply` method. - val outputPairSymbol = TempSymbol(N, "outputPair").toScrut - val bindingsSymbol = TempSymbol(N, "bindings").toScrut + val outputPairSymbol = TempSymbol(N, erasedType = N, "outputPair").toScrut + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings").toScrut val pattern = matchSuccessPattern(S(Ls(outputPairSymbol.symbol, bindingsSymbol.symbol))) // Destruct the first field of the `MatchSuccess` as a pair. - val outputSymbol = TempSymbol(N, "output").toScrut // Denotes the pattern's output. - val remainingSymbol = TempSymbol(N, "remaining").toScrut // Denotes the remaining value. + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut // Denotes the pattern's output. + val remainingSymbol = TempSymbol(N, erasedType = N, "remaining").toScrut // Denotes the remaining value. // Assemble the `Split`s in order from inside to outside. val consequent1 = makeConsequent(outputSymbol, remainingSymbol, SeqMap.empty) val consequent2 = makeTupleBranch(outputPairSymbol(), @@ -933,15 +934,15 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz ) RejectPrefixSplit case _ => (makeConsequent, alternative) => - val nonEmptySymbol = TempSymbol(N, "nonEmpty") + val nonEmptySymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "nonEmpty") val nonEmptyTerm = app( this.lt.safeRef, tup(fld(int(0)), fld(sel(scrutinee(), "length"))), "string is not empty" ) - val outputSymbol = TempSymbol(N, "stringHead") + val outputSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "stringHead") val outputTerm = callStringGet(scrutinee(), 0, "head") - val remainsSymbol = TempSymbol(N, "stringTail") + val remainsSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "stringTail") val remainsTerm = callStringDrop(scrutinee(), 1, "tail") Split.Let(nonEmptySymbol, nonEmptyTerm, Branch(nonEmptySymbol.safeRef, @@ -965,7 +966,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz makeStringPrefixMatchSplit(scrutinee, left)( (leftOutput, leftRemains, leftBindings) => makeMatchSplit(scrutinee, right)( (rightOutput, rightBindings) => - val productSymbol = TempSymbol(N, "product") + val productSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Array)), "product") val productTerm = tup(leftOutput() |> fld, rightOutput() |> fld) Split.Let(productSymbol, productTerm, makeConsequent( productSymbol.toScrut, leftRemains, leftBindings ++ rightBindings)), @@ -984,7 +985,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case Wildcard() => (makeConsequent, alternative) => // Because the wildcard pattern always matches, we can match the entire // string and returns an empty string as the remaining value. - val emptyStringSymbol = TempSymbol(N, "emptyString") + val emptyStringSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "emptyString") makeConsequent(scrutinee, emptyStringSymbol.toScrut, SeqMap.empty) Branch(scrutinee(), FlatPattern.ClassLike(ctx.builtins.Str.safeRef, ctx.builtins.Str, N, false)(Tree.Dummy), Split.Let(emptyStringSymbol, str(""), @@ -993,12 +994,12 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case Literal(prefix: StrLit) => (makeConsequent, alternative) => // Check if the scrutinee is the same as the literal. If so, we return // an empty string as the remaining value. - val isLeadingSymbol = TempSymbol(N, "isLeading") + val isLeadingSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "isLeading") val isLeadingTerm = callStringStartsWith( scrutinee(), Term.Lit(prefix), "the result of startsWith") - val outputSymbol = TempSymbol(N, "consumed") + val outputSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "consumed") val outputTerm = callStringTake(scrutinee(), prefix.value.length, "the consumed part of input") - val remainsSymbol = TempSymbol(N, "remains") + val remainsSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "remains") val remainsTerm = callStringDrop(scrutinee(), prefix.value.length, "the remaining input") Split.Let(isLeadingSymbol, isLeadingTerm, Branch(isLeadingSymbol.safeRef, @@ -1010,9 +1011,9 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case Literal(_) => RejectPrefixSplit case Range(lower: StrLit, upper: StrLit, rightInclusive) => (makeConsequent, alternative) => // Check if the string is not empty. Then - val stringHeadSymbol = TempSymbol(N, "stringHead") - val stringTailSymbol = TempSymbol(N, "stringTail") - val nonEmptySymbol = TempSymbol(N, "nonEmpty") + val stringHeadSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "stringHead") + val stringTailSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Str)), "stringTail") + val nonEmptySymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "nonEmpty") val nonEmptyTerm = app(this.lt.safeRef, tup(fld(int(0)), fld(sel(scrutinee(), "length"))), "string is not empty") Split.Let(nonEmptySymbol, nonEmptyTerm, // `0 < string.length` Branch(nonEmptySymbol.safeRef, @@ -1071,7 +1072,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val params = parameters.map: case (_, parameterSymbol) => Param(FldFlags.empty, parameterSymbol, N, Modulefulness.none) - val lambdaSymbol = new TempSymbol(N, "transform") + val lambdaSymbol = new TempSymbol(N, erasedType = N, "transform") (makeConsequent, alternative) => Split.Let( sym = lambdaSymbol, term = Term.Lam(PlainParamList(params), transform), @@ -1081,7 +1082,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz (_output, remains, bindings) => val arguments = symbols.iterator.map(bindings).map(_() |> fld).toSeq val resultTerm = app(lambdaSymbol.safeRef, tup(arguments*), "the transform's result") - val resultSymbol = TempSymbol(N, "transformResult") + val resultSymbol = TempSymbol(N, erasedType = N, "transformResult") Split.Let(resultSymbol, resultTerm, makeConsequent(resultSymbol.toScrut, remains, SeqMap.empty)), alternative)) case Guarded(pattern, guard) => @@ -1089,7 +1090,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz make.whenAccept: (makeConsequent, alternative) => make( (output, remains, bindings) => - val guardSymbol = TempSymbol(N, "guardResult") + val guardSymbol = TempSymbol(N, erasedType = N, "guardResult") val branch = Branch(guardSymbol.ref(), makeConsequent(output, remains, bindings)) val innermost = Split.Let(guardSymbol, guard, branch ~: Split.End) // The creation of bindings here is repeated with the creation of @@ -1158,13 +1159,13 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz outputNeeded: Bool, fallbackPattern: Opt[SP] ): MakeSplit = (makeConsequent, alternative) => - val matcherSymbol = TempSymbol(N, "fixedPointMatcher") + val matcherSymbol = TempSymbol(N, erasedType = N, "fixedPointMatcher") val matcherBody = Term.Blk( machine.prelude :+ Term.SynthWhile(machine.loop), machine.result) val callTerm = app(matcherSymbol.safeRef, tup(fld(scrutinee())), "fixed-point match result") Split.Let(matcherSymbol, Term.Lam(machine.params, matcherBody), tempLet("fixedPointResult", callTerm): resultSymbol => - val outputSymbol = TempSymbol(N, "output").toScrut + val outputSymbol = TempSymbol(N, erasedType = N, "output").toScrut val consequent = outputPattern match case N => makeConsequent(outputSymbol, SeqMap.empty) case S(subPattern) => @@ -1209,18 +1210,18 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz val (matcherSymbol, implementations) = compiler.buildMatcher(synonym, resultMode) val innermostSplit = resultMode match case ResultMode.MatchOnly => - val resultSymbol = TempSymbol(N, "matchSuccess") + val resultSymbol = TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Bool)), "matchSuccess") val resultTerm = app(matcherSymbol.safeRef, tup(fld(scrutinee())), "result of matcher function") Split.Let(resultSymbol, resultTerm, Branch(resultSymbol.safeRef, makeConsequent(scrutinee, SeqMap.empty)) ~: alternative) case ResultMode.Full => // 1. Bind the call result to a variable. - val recordSymbol = TempSymbol(N, "matchRecord") + val recordSymbol = TempSymbol(N, erasedType = N, "matchRecord") val recordTerm = app(matcherSymbol.safeRef, tup(fld(scrutinee())), "result of matcher function") val f1 = Split.Let(recordSymbol, recordTerm, _) // 2. Check if the direct result is a `MatchSuccess` and bind the output. - val outputSymbol = TempSymbol(N, "patternOutput") - val bindingsSymbol = TempSymbol(N, "bindings") // TODO: This is useless. + val outputSymbol = TempSymbol(N, erasedType = N, "patternOutput") + val bindingsSymbol = TempSymbol(N, erasedType = N, "bindings") // TODO: This is useless. val consequent = makeConsequent(outputSymbol.toScrut, SeqMap.empty) val pattern = matchSuccessPattern(S(outputSymbol :: bindingsSymbol :: Nil)) val branch = Branch(recordSymbol.safeRef, pattern, consequent) @@ -1279,7 +1280,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz s"compilePattern >>> $blk" ): val unapply = scoped("ucs:translation"): - val inputSymbol = VarSymbol(Ident("input")) + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) val topmost = makeMatchSplit(inputSymbol.toScrut, pd.pattern, true)( makeConsequent = (output, bindings) => def getBinding(p: Param) = bindings.get(p.sym).fold(Term.Error().withLocOf(p))(_()) @@ -1306,7 +1307,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz // We don't report errors here because they have been already reported in // the translation of `unapply` function. given Raise = Function.const(()) - val inputSymbol = VarSymbol(Ident("input")) + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) val topmost = makeStringPrefixMatchSplit(inputSymbol.toScrut, pd.pattern) ((consumedOutput, remainingOutput, bindings) => Split.Else: makeMatchSuccess(tup(fld(consumedOutput()), fld(remainingOutput()))), failure) @@ -1322,7 +1323,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz scrut: VarSymbol, topmost: Split ): Ls[Statement] = - val fieldSymbol = TempSymbol(N, name) + val fieldSymbol = TempSymbol(N, erasedType = N, name) val decl = LetDecl(fieldSymbol, Nil) val param = Param(FldFlags.empty, scrut, N, Modulefulness.none) val paramList = PlainParamList(param :: Nil) @@ -1343,7 +1344,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz case _ => N term.getOrElse: val unapply = scoped("ucs:translation"): - val inputSymbol = VarSymbol(Ident("input")) + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) val topmost = makeMatchSplit(inputSymbol.toScrut, pattern, true) ((output, bindings) => Split.Else(makeMatchSuccess(output())), failure) log(s"Translated `unapply`: ${topmost.prettyPrint}") @@ -1352,7 +1353,7 @@ class SplitCompiler(using tl: TL)(using State, Ctx, Raise) extends TermSynthesiz // We don't report errors here because they have been already reported in // the translation of `unapply` function. given Raise = Function.const(()) - val inputSymbol = VarSymbol(Ident("input")) + val inputSymbol = VarSymbol(Ident("input"), erasedType = N) val topmost = makeStringPrefixMatchSplit(inputSymbol.toScrut, pattern) ((consumedOutput, remainingOutput, bindings) => Split.Else: makeMatchSuccess(tup(fld(consumedOutput()), fld(remainingOutput()))), failure) diff --git a/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala b/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala index 77e061a5cc..5e865afede 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala @@ -657,7 +657,7 @@ trait TypeDefImpl(using State) extends TypeOrTermDef: pts.flatMap(_.desugared.asParam(inUsing = inUsing).toOption).collect: case pt @ ParamTree(ident = id, spd = N) => val k = if pt.flags.mut then MutVal else ImmutVal - TermSymbol(k, symbol.asClsLike, id) + TermSymbol(k, symbol.asClsLike, id, erasedType = N) .toList lazy val allSymbols = definedSymbols ++ diff --git a/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls new file mode 100644 index 0000000000..86796fc115 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/codegen/ErasedType.mls @@ -0,0 +1,281 @@ +:sir + +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ import "..." as Predef; end + + +// Literals: Int / Str / Num / Bool. A local picks up its initializer's type. +:siret +let i = 1 +let f = 1.5 +let s = "str" +let b = true +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ set i⁰ = 1; set f⁰ = 1.5; set s⁰ = "str"; set b⁰ = true; end + + +// Local reassignment join: a local keeps its type while assignments agree, and +// widens to `Object` on a type conflict. A declared-then-assigned local (no +// initializer) is still typed from its assignment. +:siret +let consistent = 1 +set consistent = 2 +let conflicting = 1 +set conflicting = "two" +let declared +set declared = 3 +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ set consistent⁰ = 1; +//│ set consistent⁰ = 2; +//│ set conflicting⁰ = 1; +//│ set conflicting⁰ = "two"; +//│ set declared⁰ = 3; +//│ return runtime⁰.Unit⁰ + + +// Tuple -> Array; record -> Object (the `NoSymbol` top case). +:siret +let t = [i, s] +let r = {x: i, y: s} +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let x: Int, y: Str; set t⁰ = [i⁰, s⁰]; set x = i⁰; set y = s⁰; set r⁰ = { "x": x, "y": y }; end + + +// Instantiation -> the class symbol. +:siret +class Foo(x: Int) +let foo = new Foo(123) +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define Foo⁰ as class Foo² { +//│ private val x⁰: Int; +//│ constructor Foo¹(x: Int) { set Foo².this.x⁰ = x; end } +//│ }; +//│ set foo⁰ = new Foo²(123); +//│ end + + +// Instantiation on class with `val` param -> the class symbol. +:siret +class Foo(val x: Str) +let foo = new Foo(123) +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define Foo³ as class Foo⁵ { +//│ val x¹: Str; +//│ constructor Foo⁴(x: Str) { define x¹ as val x²: Str = x; end } +//│ }; +//│ set foo¹ = new Foo⁵(123); +//│ end + + +// Call return is inferred from the callee's return type; field selection stays +// `: ?` (resolved in a later phase). +:siret +fun mk(a) = new Foo(a) +let made = mk(1) +let sel = made.x +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define mk⁰ as fun mk¹(a: ?): Foo⁵ { +//│ return new Foo⁵(a) +//│ }; +//│ set made⁰ = mk¹(1); +//│ set sel⁰ = made⁰.x﹖; +//│ end + + +// Parameter annotations. The return type is unannotated but inferred from the +// body: `Int + Int` yields `Int`. +:siret +fun add(x: Int, y: Int) = x + y +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define add⁰ as fun add¹(x: Int, y: Int): Int { return +⁰(x, y) }; end + + +// A reassigned parameter joins with its annotation: a conflicting `set` widens +// the annotated type (and the inferred return) to `Object`; a consistent `set` +// keeps it. +:siret +fun widen(x: Int) = + set x = "s" + x +fun keep(x: Int) = + set x = 5 + x +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define widen⁰ as fun widen¹(x: Object): Object { +//│ set x = "s"; +//│ return x +//│ }; +//│ define keep⁰ as fun keep¹(x: Int): Int { set x = 5; return x }; +//│ end + + +// Return-type annotations are rendered at the `fun ...: T` definition site, +// for both primitive and class return types. +:siret +fun addUp(x: Int, y: Int): Int = x + y +fun makeFoo(n: Int): Foo = new Foo(n) +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define addUp⁰ as fun addUp¹(x: Int, y: Int): Int { +//│ return +⁰(x, y) +//│ }; +//│ define makeFoo⁰ as fun makeFoo¹(n: Int): Foo⁵ { return new Foo⁵(n) }; +//│ end + + +// A rest parameter binds an array of the collected arguments, so it is typed +// `Array` regardless of any element annotation. +:siret +fun rest(a, ...xs) = xs +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define rest⁰ as fun rest¹(a: ?, ...xs: Array): Array { return xs }; end + + +// Inferred returns from the body terminal: a builtin comparison yields `Bool` +// (regardless of operand types), an instantiation yields the class, and an +// unconstrained identity return stays `: ?`. (Equality `==` lowers to a +// `Predef.equals` member call, not a builtin operator, so it stays `: ?`.) +:siret +fun lt(a, b) = a < b +fun mkFoo(a) = new Foo(a) +fun ident(a) = a +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define lt⁰ as fun lt¹(a: ?, b: ?): Bool { +//│ return <⁰(a, b) +//│ }; +//│ define mkFoo⁰ as fun mkFoo¹(a: ?): Foo⁵ { +//│ return new Foo⁵(a) +//│ }; +//│ define ident⁰ as fun ident¹(a: ?): ? { return a }; +//│ end + + +// Method return-type annotations are rendered the same way as free functions. +:siret +class Calc() with + fun twice(x: Int): Int = x + x + fun whatever(a) = a +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define Calc⁰ as class Calc² { +//│ constructor Calc¹ +//│ method twice⁰ = fun twice¹(x: Int): Int { +//│ return +⁰(x, x) +//│ } +//│ method whatever⁰ = fun whatever¹(a: ?): ? { return a } +//│ }; +//│ end + + +// A selection resolved to a member (`this.v`) carries that member's type, so a +// field getter's return is inferred. +:siret +class Box(val v: Int) with + fun get() = v +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ define Box⁰ as class Box² { +//│ val v⁰: Int; +//│ constructor Box¹(v: Int) { +//│ define v⁰ as val v¹: Int = v; +//│ end +//│ } +//│ method get⁰ = fun get¹(): Int { return Box².this.v¹ } +//│ }; +//│ end + + +// Optimized IR with erased types: surfaces where a type relaxes to `: ?` +// through a transform (the relaxation gap `:soir` makes visible). +:soir +:siret +fun id(a) = a +id(new Foo(1)) +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let tmp: Foo⁵; define id⁰ as fun id¹(a: ?): ? { return a }; set tmp = new Foo⁵(1); return id¹(tmp) +//│ ——————————————| Optimized IR |—————————————————————————————————————————————————————————————————————— +//│ define id⁰ as fun id¹(a: ?): ? { return a }; return new Foo⁵(1) + + +// Closure capture (`:lift`): the symbol holding the capture object is typed as +// the generated capture class. The capture field inherits the captured local's +// erased type — but here `n` is reassigned from `n + 1`, whose operand type is +// not yet known when the (hoisted) closure body is lowered, so the local (and +// hence the field) widens to `?`. +:siret +:lift +fun counter() = + let n = 0 + fun bump() = set n = n + 1 + bump() + n +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let bump, Capture$scope0; +//│ define Capture$scope0 as class Capture$scope0¹ { +//│ constructor Capture$scope0⁰(n$0: ?) { +//│ define n$0⁰ as val n$0¹: ? = n$0; +//│ end +//│ } +//│ }; +//│ define bump as fun bump⁰(scope0$cap: ?): Unit⁰ { +//│ let tmp: ?; +//│ set tmp = +⁰(scope0$cap.n$0¹, 1); +//│ set scope0$cap.n$0¹ = tmp; +//│ return runtime⁰.Unit⁰ +//│ }; +//│ define counter⁰ as fun counter¹(): ? { +//│ let n: ?, scope0$cap: Capture$scope0¹; +//│ set scope0$cap = new mut Capture$scope0¹(n); +//│ set scope0$cap.n$0¹ = 0; +//│ do bump⁰(scope0$cap); +//│ return scope0$cap.n$0¹ +//│ }; +//│ end + + +// When the captured local keeps a definite erased type (here `n` is only ever +// assigned `Int` literals), the generated capture field inherits it: both the +// constructor parameter `n$0` and the `val n$0` field are typed `Int`. +:siret +:lift +fun resetCounter() = + let n = 0 + fun reset() = set n = 100 + reset() + n +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let reset, Capture$scope0; +//│ define Capture$scope0 as class Capture$scope0³ { +//│ constructor Capture$scope0²(n$0: Int) { +//│ define n$0² as val n$0³: Int = n$0; +//│ end +//│ } +//│ }; +//│ define reset as fun reset⁰(scope0$cap: ?): Unit⁰ { +//│ set scope0$cap.n$0³ = 100; +//│ return runtime⁰.Unit⁰ +//│ }; +//│ define resetCounter⁰ as fun resetCounter¹(): Int { +//│ let n: Int, scope0$cap: Capture$scope0³; +//│ set scope0$cap = new mut Capture$scope0³(n); +//│ set scope0$cap.n$0³ = 0; +//│ do reset⁰(scope0$cap); +//│ return scope0$cap.n$0³ +//│ }; +//│ end + + +// A read-only capture is lifted into a by-value parameter rather than a capture +// field; that parameter inherits the captured local's erased type (`k: Int`). +// (`rd`'s return stays `?`: the hoisted closure body is lowered before `k` is +// seeded, so the return-type inference cannot yet see `k`'s type.) +:siret +:lift +fun reader() = + let k = 0 + fun rd() = k + rd() +//│ ———————————————| Lowered IR |——————————————————————————————————————————————————————————————————————— +//│ let rd; +//│ define rd as fun rd⁰(k: Int): ? { +//│ return k +//│ }; +//│ define reader⁰ as fun reader¹(): ? { let k: Int; set k = 0; return rd⁰(k) }; +//│ end diff --git a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls index 1c653b6b9a..685f701941 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Basics.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Basics.mls @@ -85,6 +85,57 @@ foo(41) //│ = 42 +// Tests whether cast elision and insertion works correctly: +// - `addOne(41)` in the top-level elides a cast since `41: i31` and `(ref i31) <: (ref null i31)` +// - `addOne(v)` in `useIt` inserts a cast since `v: anyref` and `(ref null any) :> (ref null i31)` +:noInline +:wat +fun addOne(x: Int) = x + 1 +fun useIt(v) = addOne(v) +addOne(41) + useIt(41) +//│ Wat: +//│ (module +//│ (type $TypeInfoBase (sub (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) +//│ (type $Object (sub (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) +//│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) +//│ (type $addOne (func (param $x (ref null i31)) (result (ref null any)))) +//│ (type $useIt (func (param $v (ref null any)) (result (ref null any)))) +//│ (type $entry (func (result (ref null any)))) +//│ (import "system" "plus_impl" (func $plus_impl (type $plus_impl))) +//│ (func $addOne (export "addOne") (type $addOne) (param $x (ref null i31)) (result (ref null any)) +//│ (return +//│ (call $plus_impl +//│ (local.get $x) +//│ (ref.i31 +//│ (i32.const 1))))) +//│ (func $useIt (export "useIt") (type $useIt) (param $v (ref null any)) (result (ref null any)) +//│ (return +//│ (call $addOne +//│ (ref.cast (ref null i31) +//│ (local.get $v))))) +//│ (func $entry (export "entry") (type $entry) (result (ref null any)) +//│ (local $tmp (ref null any)) +//│ (local $tmp1 (ref null any)) +//│ (block (result (ref null any)) +//│ (local.set $tmp +//│ (call $addOne +//│ (ref.i31 +//│ (i32.const 41)))) +//│ (local.set $tmp1 +//│ (call $useIt +//│ (ref.i31 +//│ (i32.const 41)))) +//│ (return +//│ (call $plus_impl +//│ (local.get $tmp) +//│ (local.get $tmp1))))) +//│ (elem $addOne declare func $addOne) +//│ (elem $useIt declare func $useIt) +//│ (elem $entry declare func $entry)) +//│ Wasm result: +//│ = 84 + + class Foo(val x) new Foo(0) //│ Wasm result: @@ -101,7 +152,7 @@ class Foo(val a) //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $a (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (param $a (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (param $a (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $a (ref null any)) (result (ref $Foo)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo //│ (i32.const 1) @@ -114,9 +165,9 @@ class Foo(val a) //│ (local.get $a)) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (result (ref $Foo)) +//│ (local $this (ref $Foo)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) @@ -128,7 +179,7 @@ class Foo(val a) //│ (return //│ (local.get $this)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) -//│ (local $tmp (ref null any)) +//│ (local $tmp (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $tmp //│ (call $Foo_ctor @@ -136,8 +187,7 @@ class Foo(val a) //│ (i32.const 42)))) //│ (return //│ (struct.get $Foo $a -//│ (ref.cast (ref $Foo) -//│ (local.get $tmp)))))) +//│ (local.get $tmp))))) //│ (elem $Foo_init declare func $Foo_init) //│ (elem $Foo_ctor declare func $Foo_ctor) //│ (elem $entry declare func $entry)) @@ -151,6 +201,69 @@ getX(Foo(42)) //│ = 42 +// Checks that `f` in `getX` is typed as `(ref $Foo)` (evident by the type annotation) +:noInline +:wat +class Foo(val x) +fun getX(f: Foo) = f.x +getX(Foo(42)) +//│ Wat: +//│ (module +//│ (type $TypeInfoBase (sub (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) +//│ (type $Object (sub (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) +//│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) +//│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) +//│ (type $Foo_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $x (ref null any)) (result (ref $Foo)))) +//│ (type $getX (func (param $f (ref $Foo)) (result (ref null any)))) +//│ (type $entry (func (result (ref null any)))) +//│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo +//│ (i32.const 1) +//│ (ref.null $TypeInfoBase))) +//│ (func $Foo_init (type $Foo_init) (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)) +//│ (block (result (ref null any)) +//│ (struct.set $Foo $x +//│ (ref.cast (ref $Foo) +//│ (local.get $this)) +//│ (local.get $x)) +//│ (return +//│ (local.get $this)))) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref $Foo)) +//│ (local $this (ref $Foo)) +//│ (block (result (ref $Foo)) +//│ (local.set $this +//│ (struct.new $Foo +//│ (global.get $Foo_typeinfo) +//│ (ref.null any))) +//│ (drop +//│ (call $Foo_init +//│ (local.get $this) +//│ (local.get $x))) +//│ (return +//│ (local.get $this)))) +//│ (func $getX (export "getX") (type $getX) (param $f (ref $Foo)) (result (ref null any)) +//│ (return +//│ (struct.get $Foo $x +//│ (local.get $f)))) +//│ (func $entry (export "entry") (type $entry) (result (ref null any)) +//│ (local $tmp (ref null any)) +//│ (block (result (ref null any)) +//│ (local.set $tmp +//│ (call $Foo_ctor +//│ (ref.i31 +//│ (i32.const 42)))) +//│ (return +//│ (call $getX +//│ (ref.cast (ref $Foo) +//│ (local.get $tmp)))))) +//│ (elem $Foo_init declare func $Foo_init) +//│ (elem $Foo_ctor declare func $Foo_ctor) +//│ (elem $getX declare func $getX) +//│ (elem $entry declare func $entry)) +//│ Wasm result: +//│ = 42 + + :wat class Foo(val x) with val y = this.x @@ -162,7 +275,7 @@ class Foo(val x) with //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any))) (field $y (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (param $x (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $x (ref null any)) (result (ref $Foo)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo //│ (i32.const 1) @@ -181,9 +294,9 @@ class Foo(val x) with //│ (local.get $this)))) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $x (ref null any)) (result (ref $Foo)) +//│ (local $this (ref $Foo)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) @@ -196,7 +309,7 @@ class Foo(val x) with //│ (return //│ (local.get $this)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) -//│ (local $tmp (ref null any)) +//│ (local $tmp (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $tmp //│ (call $Foo_ctor @@ -204,8 +317,7 @@ class Foo(val x) with //│ (i32.const 42)))) //│ (return //│ (struct.get $Foo $y -//│ (ref.cast (ref $Foo) -//│ (local.get $tmp)))))) +//│ (local.get $tmp))))) //│ (elem $Foo_init declare func $Foo_init) //│ (elem $Foo_ctor declare func $Foo_ctor) //│ (elem $entry declare func $entry)) @@ -225,7 +337,7 @@ O.y //│ (type $O_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $O (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any))) (field $y (mut (ref null any)))))) //│ (type $O_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $O_ctor (func (result (ref null any)))) +//│ (type $O_ctor (func (result (ref $O)))) //│ (type $entry (func (result (ref null any)))) //│ (type $start (func)) //│ (global $O_typeinfo (export "O_typeinfo") (ref $O_typeinfo) (struct.new $O_typeinfo @@ -247,9 +359,9 @@ O.y //│ (local.get $this)))) //│ (return //│ (local.get $this)))) -//│ (func $O_ctor (type $O_ctor) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $O_ctor (type $O_ctor) (result (ref $O)) +//│ (local $this (ref $O)) +//│ (block (result (ref $O)) //│ (local.set $this //│ (struct.new $O //│ (global.get $O_typeinfo) @@ -262,8 +374,7 @@ O.y //│ (local.get $this)))) //│ (func $start (type $start) //│ (global.set $O$inst -//│ (ref.cast (ref null $O) -//│ (call $O_ctor)))) +//│ (call $O_ctor))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) //│ (return //│ (struct.get $O $y @@ -283,5 +394,5 @@ fun bar() = 42 fun foo() = bar foo()() //│ ╔══[COMPILATION ERROR] Returning function instances is not supported -//│ ║ l.283: fun foo() = bar +//│ ║ l.394: fun foo() = bar //│ ╙── ^^^ diff --git a/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls b/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls index f39d387751..a00975aa8e 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ClassInheritance.mls @@ -31,11 +31,11 @@ c.Parent#x + (c is Parent) //│ (type $Parent_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Parent (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) //│ (type $Parent_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) -//│ (type $Parent_ctor (func (param $x (ref null any)) (result (ref null any)))) +//│ (type $Parent_ctor (func (param $x (ref null any)) (result (ref $Parent)))) //│ (type $Child_typeinfo (sub $Parent_typeinfo (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Child (sub $Parent (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any))) (field $y (mut (ref null any)))))) //│ (type $Child_init (func (param $this (ref null any)) (param $y (ref null any)) (result (ref null any)))) -//│ (type $Child_ctor (func (param $y (ref null any)) (result (ref null any)))) +//│ (type $Child_ctor (func (param $y (ref null any)) (result (ref $Child)))) //│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $entry (func (result (ref null any)))) //│ (import "system" "plus_impl" (func $plus_impl (type $plus_impl))) @@ -54,9 +54,9 @@ c.Parent#x + (c is Parent) //│ (local.get $x)) //│ (return //│ (local.get $this)))) -//│ (func $Parent_ctor (export "Parent") (type $Parent_ctor) (param $x (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Parent_ctor (export "Parent") (type $Parent_ctor) (param $x (ref null any)) (result (ref $Parent)) +//│ (local $this (ref $Parent)) +//│ (block (result (ref $Parent)) //│ (local.set $this //│ (struct.new $Parent //│ (global.get $Parent_typeinfo) @@ -85,9 +85,9 @@ c.Parent#x + (c is Parent) //│ (local.get $y)) //│ (return //│ (local.get $this)))) -//│ (func $Child_ctor (export "Child") (type $Child_ctor) (param $y (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Child_ctor (export "Child") (type $Child_ctor) (param $y (ref null any)) (result (ref $Child)) +//│ (local $this (ref $Child)) +//│ (block (result (ref $Child)) //│ (local.set $this //│ (struct.new $Child //│ (global.get $Child_typeinfo) @@ -103,7 +103,7 @@ c.Parent#x + (c is Parent) //│ (local $matchRes (ref null any)) //│ (local $currentTypeInfo (ref null any)) //│ (local $targetTypeInfo (ref null any)) -//│ (local $typeInfoMatch (ref null any)) +//│ (local $typeInfoMatch (ref null i31)) //│ (block (result (ref null any)) //│ (global.set $c //│ (call $Child_ctor diff --git a/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls b/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls index facca545f7..b960180970 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ClassMethods.mls @@ -15,7 +15,7 @@ a.A#get() + A(2).get() //│ (type $A_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $A (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) //│ (type $A_init (func (param $this (ref null any)) (param $x (ref null any)) (result (ref null any)))) -//│ (type $A_ctor (func (param $x (ref null any)) (result (ref null any)))) +//│ (type $A_ctor (func (param $x (ref null any)) (result (ref $A)))) //│ (type $A_get (func (param $this (ref null any)) (result (ref null any)))) //│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $entry (func (result (ref null any)))) @@ -32,9 +32,9 @@ a.A#get() + A(2).get() //│ (local.get $x)) //│ (return //│ (local.get $this)))) -//│ (func $A_ctor (export "A") (type $A_ctor) (param $x (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $A_ctor (export "A") (type $A_ctor) (param $x (ref null any)) (result (ref $A)) +//│ (local $this (ref $A)) +//│ (block (result (ref $A)) //│ (local.set $this //│ (struct.new $A //│ (global.get $A_typeinfo) diff --git a/hkmc2/shared/src/test/mlscript/wasm/Matching.mls b/hkmc2/shared/src/test/mlscript/wasm/Matching.mls index f3b8ec74ea..0c5a5a1c9a 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Matching.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Matching.mls @@ -48,11 +48,11 @@ if Bar(true) is //│ (type $Bar_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Bar (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $y (mut (ref null any)))))) //│ (type $Bar_init (func (param $this (ref null any)) (param $y (ref null any)) (result (ref null any)))) -//│ (type $Bar_ctor (func (param $y (ref null any)) (result (ref null any)))) +//│ (type $Bar_ctor (func (param $y (ref null any)) (result (ref $Bar)))) //│ (type $Baz_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Baz (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $z (mut (ref null any)))))) //│ (type $Baz_init (func (param $this (ref null any)) (param $z (ref null any)) (result (ref null any)))) -//│ (type $Baz_ctor (func (param $z (ref null any)) (result (ref null any)))) +//│ (type $Baz_ctor (func (param $z (ref null any)) (result (ref $Baz)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Bar_typeinfo (export "Bar_typeinfo") (ref $Bar_typeinfo) (struct.new $Bar_typeinfo //│ (i32.const 1) @@ -68,9 +68,9 @@ if Bar(true) is //│ (local.get $y)) //│ (return //│ (local.get $this)))) -//│ (func $Bar_ctor (export "Bar") (type $Bar_ctor) (param $y (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Bar_ctor (export "Bar") (type $Bar_ctor) (param $y (ref null any)) (result (ref $Bar)) +//│ (local $this (ref $Bar)) +//│ (block (result (ref $Bar)) //│ (local.set $this //│ (struct.new $Bar //│ (global.get $Bar_typeinfo) @@ -89,9 +89,9 @@ if Bar(true) is //│ (local.get $z)) //│ (return //│ (local.get $this)))) -//│ (func $Baz_ctor (export "Baz") (type $Baz_ctor) (param $z (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Baz_ctor (export "Baz") (type $Baz_ctor) (param $z (ref null any)) (result (ref $Baz)) +//│ (local $this (ref $Baz)) +//│ (block (result (ref $Baz)) //│ (local.set $this //│ (struct.new $Baz //│ (global.get $Baz_typeinfo) @@ -108,11 +108,11 @@ if Bar(true) is //│ (local $matchRes (ref null any)) //│ (local $currentTypeInfo (ref null any)) //│ (local $targetTypeInfo (ref null any)) -//│ (local $typeInfoMatch (ref null any)) +//│ (local $typeInfoMatch (ref null i31)) //│ (local $matchRes1 (ref null any)) //│ (local $currentTypeInfo1 (ref null any)) //│ (local $targetTypeInfo1 (ref null any)) -//│ (local $typeInfoMatch1 (ref null any)) +//│ (local $typeInfoMatch1 (ref null i31)) //│ (block (result (ref null any)) //│ (local.set $scrut //│ (call $Bar_ctor diff --git a/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls b/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls index 3333b25518..4def17a2d3 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/ScopedLocals.mls @@ -149,7 +149,7 @@ class Foo(val a, val b) //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $a (mut (ref null any))) (field $b (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (param $a (ref null any)) (param $b (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (param $a (ref null any)) (param $b (ref null any)) (result (ref null any)))) +//│ (type $Foo_ctor (func (param $a (ref null any)) (param $b (ref null any)) (result (ref $Foo)))) //│ (type $entry (func (result (ref null any)))) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo //│ (i32.const 1) @@ -166,9 +166,9 @@ class Foo(val a, val b) //│ (local.get $b)) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (param $b (ref null any)) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Foo_ctor (export "Foo") (type $Foo_ctor) (param $a (ref null any)) (param $b (ref null any)) (result (ref $Foo)) +//│ (local $this (ref $Foo)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) @@ -182,7 +182,7 @@ class Foo(val a, val b) //│ (return //│ (local.get $this)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) -//│ (local $tmp (ref null any)) +//│ (local $tmp (ref $Foo)) //│ (block (result (ref null any)) //│ (local.set $tmp //│ (call $Foo_ctor @@ -192,8 +192,7 @@ class Foo(val a, val b) //│ (i32.const 1)))) //│ (return //│ (struct.get $Foo $a -//│ (ref.cast (ref $Foo) -//│ (local.get $tmp)))))) +//│ (local.get $tmp))))) //│ (elem $Foo_init declare func $Foo_init) //│ (elem $Foo_ctor declare func $Foo_ctor) //│ (elem $entry declare func $entry)) diff --git a/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls b/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls index adb7c68bcf..51b3f24893 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/SingletonUnit.mls @@ -11,7 +11,7 @@ //│ (type $Unit_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Unit (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $Unit_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $Unit_ctor (func (result (ref null any)))) +//│ (type $Unit_ctor (func (result (ref $Unit)))) //│ (type $entry (func (result (ref null any)))) //│ (type $start (func)) //│ (global $Unit_typeinfo (export "Unit_typeinfo") (ref $Unit_typeinfo) (struct.new $Unit_typeinfo @@ -21,9 +21,9 @@ //│ (func $Unit_init (type $Unit_init) (param $this (ref null any)) (result (ref null any)) //│ (return //│ (local.get $this))) -//│ (func $Unit_Unit (type $Unit_ctor) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Unit_Unit (type $Unit_ctor) (result (ref $Unit)) +//│ (local $this (ref $Unit)) +//│ (block (result (ref $Unit)) //│ (local.set $this //│ (struct.new $Unit //│ (global.get $Unit_typeinfo))) @@ -34,8 +34,7 @@ //│ (local.get $this)))) //│ (func $start (type $start) //│ (global.set $Unit$inst -//│ (ref.cast (ref null $Unit) -//│ (call $Unit_Unit)))) +//│ (call $Unit_Unit))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) //│ (return //│ (global.get $Unit$inst))) diff --git a/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls b/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls index 2bd903c1b3..d0ca4f6efa 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/Singletons.mls @@ -16,11 +16,11 @@ Bar.y //│ (type $Foo_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Foo (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $x (mut (ref null any)))))) //│ (type $Foo_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $Foo_ctor (func (result (ref null any)))) +//│ (type $Foo_ctor (func (result (ref $Foo)))) //│ (type $Bar_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase))))) //│ (type $Bar (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase))) (field $y (mut (ref null any)))))) //│ (type $Bar_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $Bar_ctor (func (result (ref null any)))) +//│ (type $Bar_ctor (func (result (ref $Bar)))) //│ (type $entry (func (result (ref null any)))) //│ (type $start (func)) //│ (global $Foo_typeinfo (export "Foo_typeinfo") (ref $Foo_typeinfo) (struct.new $Foo_typeinfo @@ -40,9 +40,9 @@ Bar.y //│ (i32.const 1))) //│ (return //│ (local.get $this)))) -//│ (func $Foo_ctor (type $Foo_ctor) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Foo_ctor (type $Foo_ctor) (result (ref $Foo)) +//│ (local $this (ref $Foo)) +//│ (block (result (ref $Foo)) //│ (local.set $this //│ (struct.new $Foo //│ (global.get $Foo_typeinfo) @@ -61,9 +61,9 @@ Bar.y //│ (i32.const 2))) //│ (return //│ (local.get $this)))) -//│ (func $Bar_ctor (type $Bar_ctor) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $Bar_ctor (type $Bar_ctor) (result (ref $Bar)) +//│ (local $this (ref $Bar)) +//│ (block (result (ref $Bar)) //│ (local.set $this //│ (struct.new $Bar //│ (global.get $Bar_typeinfo) @@ -76,11 +76,9 @@ Bar.y //│ (func $start (type $start) //│ (block //│ (global.set $Foo$inst -//│ (ref.cast (ref null $Foo) -//│ (call $Foo_ctor))) +//│ (call $Foo_ctor)) //│ (global.set $Bar$inst -//│ (ref.cast (ref null $Bar) -//│ (call $Bar_ctor))))) +//│ (call $Bar_ctor)))) //│ (func $entry (export "entry") (type $entry) (result (ref null any)) //│ (return //│ (struct.get $Bar $y diff --git a/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls b/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls index aafe528991..e497d1c179 100644 --- a/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls +++ b/hkmc2/shared/src/test/mlscript/wasm/VirtualMethods.mls @@ -28,12 +28,12 @@ callF(B()) //│ (type $A_typeinfo (sub $TypeInfoBase (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase)) (field $slot0 (mut (ref null $virtual2)))))) //│ (type $A (sub $Object (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $A_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $A_ctor (func (result (ref null any)))) +//│ (type $A_ctor (func (result (ref $A)))) //│ (type $B_typeinfo (sub $A_typeinfo (struct (field $$tag i32) (field $$parent (ref null $TypeInfoBase)) (field $slot0 (mut (ref null $virtual2)))))) //│ (type $B (sub $A (struct (field $$typeinfo (mut (ref $TypeInfoBase)))))) //│ (type $B_init (func (param $this (ref null any)) (result (ref null any)))) -//│ (type $B_ctor (func (result (ref null any)))) -//│ (type $callF (func (param $a (ref null any)) (result (ref null any)))) +//│ (type $B_ctor (func (result (ref $B)))) +//│ (type $callF (func (param $a (ref $A)) (result (ref null any)))) //│ (type $plus_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $times_impl (func (param $lhs (ref null any)) (param $rhs (ref null any)) (result (ref null any)))) //│ (type $entry (func (result (ref null any)))) @@ -50,9 +50,9 @@ callF(B()) //│ (func $A_init (type $A_init) (param $this (ref null any)) (result (ref null any)) //│ (return //│ (local.get $this))) -//│ (func $A_ctor (export "A") (type $A_ctor) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $A_ctor (export "A") (type $A_ctor) (result (ref $A)) +//│ (local $this (ref $A)) +//│ (block (result (ref $A)) //│ (local.set $this //│ (struct.new $A //│ (global.get $A_typeinfo))) @@ -74,9 +74,9 @@ callF(B()) //│ (local.get $this))) //│ (return //│ (local.get $this)))) -//│ (func $B_ctor (export "B") (type $B_ctor) (result (ref null any)) -//│ (local $this (ref null any)) -//│ (block (result (ref null any)) +//│ (func $B_ctor (export "B") (type $B_ctor) (result (ref $B)) +//│ (local $this (ref $B)) +//│ (block (result (ref $B)) //│ (local.set $this //│ (struct.new $B //│ (global.get $B_typeinfo))) @@ -91,7 +91,7 @@ callF(B()) //│ (local.get $x) //│ (ref.i31 //│ (i32.const 2))))) -//│ (func $callF (export "callF") (type $callF) (param $a (ref null any)) (result (ref null any)) +//│ (func $callF (export "callF") (type $callF) (param $a (ref $A)) (result (ref null any)) //│ (local $receiver (ref null any)) //│ (return //│ (block (result (ref null any)) diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala index 70aba1a3bf..fb8499dec4 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala @@ -105,6 +105,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: lazy val blockPrinter = given ShowCfg = ShowCfg( + showErasedTypes = showIRErasedTypes.isSet, showExpansionMappings = false, showFlowSymbols = true, debug = debug.isSet, @@ -151,6 +152,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: output(prog.showAsTree) if showIR.isSet || showIRLines.isSet then given ShowCfg = ShowCfg( + showErasedTypes = showIRErasedTypes.isSet, showExpansionMappings = false, showFlowSymbols = true, debug = debug.isSet, @@ -170,6 +172,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: if showOptimizedIR.isSet then outputSeparator("Optimized IR") given ShowCfg = ShowCfg( + showErasedTypes = showIRErasedTypes.isSet, showExpansionMappings = false, showFlowSymbols = true, debug = debug.isSet, @@ -201,7 +204,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val exportedScoped = symbolsToPreserve.collect: case sym: ScopedSymbol if !importedSymbols.contains(sym) => sym - val resSym = new TempSymbol(N, "block$res") + val resSym = new TempSymbol(N, erasedType = N, "block$res") val resNme = nestedScp.allocateName(resSym) @@ -277,7 +280,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val le = import codegen.* Assign( - Elaborator.State.noSymbol, + NoSymbol, Call( Elaborator.State.runtimeSymbol.asSimpleRef.selSN("printRaw"), (Arg(N, sym.asPath) :: Nil) ne_:: Nil)(CallMetadata.defaultMlsFun), diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index 94d57698fe..c04ea4c8ec 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -47,6 +47,10 @@ abstract class MLsDiffMaker extends DiffMaker: val showOptimizedTree = NullaryCommand("olot") val debugOptimizations = NullaryCommand("dopt") val noOptimizations = NullaryCommand("noOpt") + val showIRErasedTypes = NullaryCommand("siret", () => + if showIR.isUnset && showOptimizedIR.isUnset then + output("Option ':siret' only has an effect if ':sir' or ':soir' is also set") + ) val showContext = NullaryCommand("ctx") val parseOnly = NullaryCommand("parseOnly") val funcToCls = NullaryCommand("ftc") @@ -469,6 +473,7 @@ abstract class MLsDiffMaker extends DiffMaker: if showFlows.isSet then import semantics.ShowCfg given ShowCfg = ShowCfg( + showErasedTypes = showIRErasedTypes.isSet, showExpansionMappings = true, showFlowSymbols = true, debug = debug.isSet,