Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
5dfa1b8
Add `PrimitiveType`, `ErasedType`, and refactor `Value.Lit`
Derppening May 27, 2026
5dab8dd
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening May 28, 2026
e3509d7
codegen: Add ErasedType for Value.This
Derppening May 28, 2026
f0fe27c
codegen: Remove `Value.Lit.erasedType`
Derppening May 28, 2026
df6d065
codegen: Add `MemberRef` to erasedType
Derppening May 28, 2026
fbe4104
codegen: Add `erasedType` to `SimpleRef`
Derppening May 29, 2026
d389627
WIP: Remove `SimpleRef.erasedType`
Derppening May 29, 2026
0379ae9
codegen: Implement `erasedType` in symbols
Derppening May 30, 2026
5c358ee
codegen: Add `ErasedType.ObjectRef`
Derppening May 30, 2026
f3bfba1
codegen: Bubble `HasErasedType` up to `Result`
Derppening May 30, 2026
7112f54
codegen: Tighten the erased type of some symbols
Derppening Jun 1, 2026
6dc04ca
codegen/js: Fix comment alignment
Derppening Jun 1, 2026
ab03528
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 1, 2026
1e0799c
semantics: Drop `using State` from `NoSymbol`
Derppening Jun 1, 2026
2d6d283
codegen: Fold `ObjectRef` as `AnyRef(_, NoSymbol)`
Derppening Jun 1, 2026
8ba7fe1
codegen: Implement printing of erased types
Derppening Jun 1, 2026
0c271c9
codegen: Implement `HasRefinableErasedType`
Derppening Jun 1, 2026
1620a30
codegen: Add erased type refinement to parameters
Derppening Jun 1, 2026
fe4b63b
codegen: Implement erased type refinement for class fields
Derppening Jun 1, 2026
60d531b
difftest: Add more padding to test cases in `ErasedType.mls`
Derppening Jun 1, 2026
adfc743
[WIP] codegen/wasm: Implement `ErasedType.wasmType`
Derppening Jun 1, 2026
dd94b46
WIP: Fixup `refineErasedType` implementation and docstring
Derppening Jun 1, 2026
197acde
codegen: Update printer to print type annotations on `let`
Derppening Jun 1, 2026
e3ff6ef
codegen: Do not show annotations for types
Derppening Jun 2, 2026
f518f15
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 2, 2026
df3b6ae
[WIP] codegen/wasm: Restrict type of local `$this`
Derppening Jun 2, 2026
b33b146
Revert dd94b46
Derppening Jun 3, 2026
2838dad
Update binaryen.js
Derppening Jun 4, 2026
52f6eb8
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 4, 2026
178a339
WIP: Implement `FuncRef`, refine in `Lowering`
Derppening Jun 4, 2026
4203e00
WIP: Add refinement of return types and operator types
Derppening Jun 4, 2026
202ef1d
Lowering: Correctly guard population of function types
Derppening Jun 5, 2026
4ddf252
codegen: Implement erased type relaxation
Derppening Jun 5, 2026
4f63baf
codegen: Propagate call types
Derppening Jun 5, 2026
f3bca96
codegen: Propagate selections
Derppening Jun 5, 2026
72a334b
codegen: Infer rest params
Derppening Jun 5, 2026
5e49f52
codegen: Infer type of lifter capture symbols
Derppening Jun 5, 2026
ca6dd9d
Rename refine -> populate
Derppening Jun 5, 2026
eaa56a6
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 5, 2026
4427c92
Rewrite Scaladoc, add TODOs
Derppening Jun 5, 2026
68b8c9e
codegen: Split `HasMutableErasedType`
Derppening Jun 5, 2026
8c380bd
codegen: Make `N: Opt[ErasedType]` mean an unknown type
Derppening Jun 5, 2026
4085dff
codegen: Propagate erased types of capture fields
Derppening Jun 5, 2026
f03bba6
codegen: Propagate erased types of params-from-lifted locals
Derppening Jun 5, 2026
7656e2b
codegen: Expand `FuncRef` and remove top-level optionality
Derppening Jun 5, 2026
6aa62d6
semantics: Make `NoSymbol` a `case object`
Derppening Jun 5, 2026
8fdc420
semantics: Make `NoSymbol` a plain `object`
Derppening Jun 5, 2026
fc5367d
semantics: Refactor and deprecate `State.noSymbol`
Derppening Jun 5, 2026
098ef02
Fix more `State.noSymbol` deprecation warnings
Derppening Jun 5, 2026
590e9a6
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 6, 2026
59e902d
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 7, 2026
df99448
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 17, 2026
23d6aaf
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 30, 2026
52b2912
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jun 30, 2026
bcbc587
Merge remote-tracking branch 'upstream/hkmc2' into enhance/typed-ir
Derppening Jul 2, 2026
3665ba2
WIP: Implement lowering of `Int` -> `i31ref`
Derppening Jul 2, 2026
9155768
WIP: Implement lowering of classes to specific refs
Derppening Jul 2, 2026
47e729c
wasm: Return concrete ref-type from constructors
Derppening Jul 2, 2026
3697356
wasm: Further restrict types of locals
Derppening Jul 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 144 additions & 8 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import hkmc2.semantics.{Term => st}
import syntax.{Literal, Tree, SpreadKind, Keyword}
import semantics.*
import semantics.Term.*
import sem.Elaborator.State
import sem.Elaborator.{Ctx, State, ctx}
import sourcecode.{FileName, Line}


/* Important design notes.
Expand Down Expand Up @@ -45,7 +46,7 @@ case class Program(
type SimpleSymbol = LocalVarSymbol | BuiltinSymbol

/** Symbol that can be used as the left-hand side of an `Assign`. */
type Assignable = LocalVarSymbol | NoSymbol
type Assignable = LocalVarSymbol | NoSymbol.type

/** Symbols that `Scoped` introduces as block-local bindings.
* This deliberately excludes things like `TermSymbol`s, which never need to be scoped.
Expand Down Expand Up @@ -552,7 +553,7 @@ object HandleBlock:
N, sym, PlainParamList(Param(FldFlags.empty, handler.resumeSym, N, Modulefulness.none) :: Nil) :: Nil,
handler.body
)(N, annotations = Nil)
val rSym = TempSymbol(N, "suspendRes")
val rSym = TempSymbol(N, erasedType = N, "suspendRes")
FunDefn.withFreshSymbol(
S(cls),
handler.sym,
Expand All @@ -570,7 +571,7 @@ object HandleBlock:
N, Nil,
S(par), handlerMtds, Nil, Nil,
// Apparently, the lifter is not happy with any assignment in the preCtor...
Assign(State.noSymbol, Call(State.builtinOpsMap("super").asSimpleRef, args.map(_.asArg) ne_:: Nil)(CallMetadata.mlsFunWithEffect), End()),
Assign(NoSymbol, Call(State.builtinOpsMap("super").asSimpleRef, args.map(_.asArg) ne_:: Nil)(CallMetadata.mlsFunWithEffect), End()),
End(),
N,
N,
Expand Down Expand Up @@ -699,7 +700,7 @@ final case class FunDefn(

object FunDefn:
def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(configOverride: Opt[Config], annotations: Ls[Annot])(using State) =
val tSym = TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme))
val tSym = TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme), erasedType = N)
sym.tsym = S(tSym)
FunDefn(owner, sym, tSym, params, body)(configOverride, annotations)

Expand All @@ -725,7 +726,7 @@ object ValDefn:
annotations: Ls[Annot],
)(using State)
: ValDefn =
ValDefn(tsym = TermSymbol(k, owner, Tree.Ident(sym.nme)), sym = sym, rhs = rhs)(configOverride, annotations)
ValDefn(tsym = TermSymbol(k, owner, Tree.Ident(sym.nme), erasedType = rhs.erasedType), sym, rhs)(configOverride, annotations)


/*
Expand Down Expand Up @@ -857,9 +858,121 @@ enum Case:
case Tup(_, _) => Set.empty
case Field(_, _) => Set.empty

/** A primitive type of the block IR. */
enum PrimitiveType:
case Unit, Int, Int31, Num, Str, Bool, Array

/** The symbol for this primitive type. */
def sym(using Ctx, State): ClassLikeSymbol = this match
case Unit => summon[State].unitSymbol
case Int => ctx.builtins.Int
case Int31 => ctx.builtins.Int31
case Num => ctx.builtins.Num
case Str => ctx.builtins.Str
case Bool => ctx.builtins.Bool
case Array => ctx.builtins.Array

object ErasedType:
def ObjectRef: ErasedType.AnyRef = AnyRef(rsc = false, NoSymbol)

/** Maps a [[`ClassLikeSymbol`]] into the canonical [[`ErasedType`]]. */
def fromClsLikeSymbol(csym: ClassLikeSymbol, rsc: Bool)(using Ctx, State): ErasedType =
PrimitiveType.values.find(_.sym === csym) match
case _ if csym === ctx.builtins.Object => ObjectRef
case S(prim) => ErasedType.Primitive(prim)
case _ => ErasedType.AnyRef(rsc, csym)

/** Joins two possibly-unknown erased types and returning the result.
*
* - If both sides are known and equal, returns that known type.
* - If both sides are known and unequal, returns the top type [[`ErasedType.ObjectRef`]].
* - If either side is unknown (`N`), returns `N` - Representing a possibly-unknown type.
*/
// TODO(Derppening): Widen conflicting knowns to their common ancestor once erased types track subtyping,
// rather than going directly to the top type.
def join(lhs: Opt[ErasedType], rhs: Opt[ErasedType]): Opt[ErasedType] = (lhs, rhs) match
case (N, _) | (_, N) => N
case (S(l), S(r)) => if l == r then S(l) else S(ObjectRef)

/** A generics-erased type of the Block IR. */
enum ErasedType:
/**
* An reference to a class-like symbol.
*
* If `csym` is `NoSymbol`, this represents the top type (`Object`).
*
* - `rsc` is true if this reference is a resource class.
*/
case AnyRef(rsc: Bool, csym: ClassLikeSymbol | NoSymbol.type)

/** A reference to a function of a possibly-known shape. */
case FuncRef(params: Ls[Opt[ErasedType]], ret: Opt[ErasedType])

/** An primitive type. */
case Primitive(prim: PrimitiveType)

/** The symbol for this erased type. */
def sym(using Ctx, State): ClassLikeSymbol = this match
case AnyRef(_, csym: ClassLikeSymbol) => csym
case AnyRef(_, NoSymbol) => ctx.builtins.Object
case FuncRef(_, _) => ctx.builtins.Function
case Primitive(prim) => prim.sym

/** Trait representing a Block IR element that has an [[`ErasedType`]]. */
trait HasErasedType:
/** The [[`ErasedType`]] of this element, or `N` if the erased type is not known. */
def erasedType: Opt[ErasedType]

/** Similar to `erasedType`, but coerces to the top type if the specific erased type is not known. */
def erasedType_! : ErasedType = erasedType.getOrElse(ErasedType.ObjectRef)

/** A [[`HasErasedType`]] whose erased type can be populated exactly once post-construction. */
trait HasOnceMutableErasedType extends HasErasedType:
// Implementation Note: Provided for overriding classes to implement `erasedType` directly as an `override var`
def erasedType_=(newType: Opt[ErasedType]): Unit

/** Populates the erased type, or raises a soft assertion if the type was already populated. */
def populateErasedType(newType: ErasedType)(using Line, FileName, Raise): Unit =
// TODO(Derppening): Restore `erasedType.isEmpty` once JS sanitization is converted into a pass, allowing us to
// only lower each program once
softAssert(erasedType.forall(_ == newType), s"Cannot refine already-refined erased type $erasedType to $newType")
if erasedType.isEmpty then erasedType = S(newType)

/** A [[`HasOnceMutableErasedType`]] whose erased type can additionally be assigned multiple times, joining each
* observed type into the running erased type via [[`ErasedType.join`]]. */
trait HasManyMutableErasedType extends HasOnceMutableErasedType:
/** Tracks whether [[`observeErasedTypeAssign`]] has seen at least one assignment to this symbol. Needed to
* distinguish a never-observed symbol (whose `N` erased type is the join unit, to be seeded by the first
* observation) from an observed-but-poisoned one (whose `N` erased type is the absorbing top). */
private var erasedTypeObserved: Bool = false

/**
* Observes an assignment of a value with type `observed` to this symbol.
*
* Some symbols (e.g. [[`LocalVarSymbol`]]) can be assigned multiple times with values of differing types. The
* first observation of an otherwise-unset symbol seeds the erased type directly; every subsequent observation
* (and any observation joining an annotation-populated type) is folded in with [[`ErasedType.join`]]. Under that
* join an unknown (`N`) assignment is "poison": it widens the symbol to the unknown top type and is absorbing,
* while two differing known types widen to the definitive top type [[`ErasedType.ObjectRef`]].
*/
def observeErasedTypeAssign(observed: Opt[ErasedType]): Unit =
if !erasedTypeObserved && erasedType.isEmpty then
erasedType = observed
else
erasedType = ErasedType.join(erasedType, observed)
erasedTypeObserved = true

extension (lit: Literal)
def erasedType: ErasedType = lit match
case Tree.UnitLit(_) => ErasedType.Primitive(PrimitiveType.Unit)
case Tree.IntLit(_) => ErasedType.Primitive(PrimitiveType.Int)
case Tree.DecLit(_) => ErasedType.Primitive(PrimitiveType.Num)
case Tree.StrLit(_) => ErasedType.Primitive(PrimitiveType.Str)
case Tree.BoolLit(_) => ErasedType.Primitive(PrimitiveType.Bool)

sealed trait TrivialResult extends Result

sealed abstract class Result extends AutoLocated:
sealed abstract class Result extends AutoLocated, HasErasedType:
// // * Used for debugging locations:
// sealed abstract class Result extends AutoLocated with ProductWithExtraInfo:
// def extraInfo: Str = toLoc.toString
Expand Down Expand Up @@ -944,6 +1057,29 @@ sealed abstract class Result extends AutoLocated:
case Value.Lit(lit) => 0
case DynSelect(qual, fld, arrayIdx) => qual.size + fld.size

lazy val erasedType: Opt[ErasedType] = this match
case Tuple(_, _) => S(ErasedType.Primitive(PrimitiveType.Array))
case Record(_, _) => S(ErasedType.ObjectRef)
case Instantiate(_, cls, _) => cls.targetSymbol.flatMap(_.asClsOrMod).map(ErasedType.AnyRef(rsc = false, _))
case Value.SimpleRef(sym) => sym.erasedType
case Value.MemberRef(_, disamb: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => disamb.erasedType
case Value.This(clsOrMod: (ClassSymbol | ModuleOrObjectSymbol)) => clsOrMod.erasedType
case Value.Lit(lit) => S(lit.erasedType)
case Call(Value.SimpleRef(bs: BuiltinSymbol), argss) =>
bs.resultErasedType(argss.head.map(_.value.erasedType))
case Call(fun, _) => fun.targetSymbol match
case S(ts: TermSymbol) => ts.erasedType match
case S(ErasedType.FuncRef(_, ret)) => ret
case _ => N
case _ => N
// * A resolved selection has the type of the member it refers to (e.g. `this.field`); an
// * unresolved selection (dynamic field access) stays unknown.
case sel @ Select(_, _) => sel.symbol match
case S(ts: TermSymbol) => ts.erasedType
case S(d: (ClassSymbol | ModuleOrObjectSymbol | TypeAliasSymbol)) => d.erasedType
case _ => N
case _ => N

/* mayRaiseEffects indicates whether this call may raise effect (algebraic effect),
* regardless of whether the check for effect is inserted or not.
* Note that the check for effect is inserted during HandlerLowering and setting this to true
Expand Down Expand Up @@ -1085,7 +1221,7 @@ object Value:
case SimpleRef(l) => l
case MemberRef(bms, disamb) => bms
case This(sym) => sym

@deprecated("Use Value.SimpleRef, Value.MemberRef, or Value.This instead.")
object Ref:
def apply(l: ValueSymbol | NoSymbol.type, disamb: Opt[DefinitionSymbol[?]]): Value.RefLike =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1538,7 +1538,7 @@ class BlockSimplifier
def go(acc: Block => Block, args: List[(VarSymbol, Result)], mapping: Map[Symbol, Symbol]): Block =
args match
case Nil =>
val resSym = TempSymbol(N, "inlinedVal")
val resSym = TempSymbol(N, erasedType = N, "inlinedVal")
val copier = Copier(resSym, mapping)
val newBlk = copier.applyBlock(blk)
if extraArgss.isEmpty then
Expand All @@ -1550,7 +1550,7 @@ class BlockSimplifier
annotations = c.metadata.annotations.filterNot(_ == Annot.TailCall),
))))))
case (sym, value) :: argRest =>
val newSym = VarSymbol(sym.id)
val newSym = VarSymbol(sym.id, erasedType = N)
go(acc.assignScoped(newSym, value), argRest, mapping + (sym -> newSym))
go(blockBuilder, matchedArgs, Map.empty)
case _ => super.applyResult(r)(k)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ class BufferableTransform()(using Ctx, State, Raise):
cls.bufferable.fold(super.applyDefn(defn)(k)): bufferable =>
val companionSym = ModuleOrObjectSymbol(DummyTypeDef(syntax.Mod), new Tree.Ident(cls.sym.nme))
val clsSizeSym = BlockMemberSymbol("size", Nil, false)
val clsSizeTermSym = TermSymbol(syntax.ImmutVal, S(companionSym), new Tree.Ident("size"))
val clsSizeTermSym = TermSymbol(syntax.ImmutVal, S(companionSym), new Tree.Ident("size"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int)))
val pubFieldMap: Map[Symbol, Symbol] = cls.publicFields.toMap
val fields = cls.privateFields ++ cls.publicFields.map(_._2)
val fieldMap: Map[Symbol, Int] = fields.zipWithIndex.toMap
def mkSymbolReplacer(params: List[ParamList]): (List[ParamList], Map[SimpleSymbol, SimpleSymbol]) =
val allVars = params.flatMap(_.allParams).map(_.sym)
val varMap = allVars
.map: sym =>
(sym, VarSymbol(sym.id))
(sym, VarSymbol(sym.id, erasedType = N))
.toMap
def mapParam(p: Param) =
Param(p.flags, varMap(p.sym), p.sign, p.modulefulness)
(params.map(pl => ParamList(pl.flags, pl.params.map(mapParam), pl.restParam.map(mapParam))), varMap.toMap)
def mkFieldReplacer(buf: VarSymbol, baseIdx: VarSymbol, symMap: Map[SimpleSymbol, SimpleSymbol]) =
def getOffset(off: Int)(k: Path => Block): Block =
val idxSymbol = new TempSymbol(N, "idx")
val idxSymbol = new TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "idx")
Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(CallMetadata.defaultMlsFun),
k(DynSelect(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true))))
def assignToOffset(off: Int, r: Result, rst: Block) =
val idxSymbol = new TempSymbol(N, "idx")
val idxSymbol = new TempSymbol(N, erasedType = S(ErasedType.Primitive(PrimitiveType.Int)), "idx")
Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asSimpleRef, (baseIdx.asSimpleRef.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil) ne_:: Nil)(CallMetadata.defaultMlsFun),
AssignDynField(buf.asSimpleRef.selSN("buf"), idxSymbol.asSimpleRef, true, r, applyBlock(rst))))
new BlockTransformer(SymbolSubst.Id):
Expand Down Expand Up @@ -73,11 +73,11 @@ class BufferableTransform()(using Ctx, State, Raise):
k(res)
case _ => super.applyPath(p)(k)
def transformFunDefn(f: FunDefn, isCtor: Bool): FunDefn =
val buf = VarSymbol(new Tree.Ident("buf"))
val idx = VarSymbol(new Tree.Ident("idx"))
val buf = VarSymbol(new Tree.Ident("buf"), erasedType = N)
val idx = VarSymbol(new Tree.Ident("idx"), erasedType = S(ErasedType.Primitive(PrimitiveType.Int)))
val (newParams, symMap) = mkSymbolReplacer(f.params)
val blk = mkFieldReplacer(buf, idx, symMap).applyBlock(f.body)
FunDefn(f.owner, f.sym, TermSymbol(f.dSym.k, f.dSym.owner, f.dSym.id), PlainParamList(
FunDefn(f.owner, f.sym, TermSymbol(f.dSym.k, f.dSym.owner, f.dSym.id, erasedType = N), PlainParamList(
Param(FldFlags.empty, buf, N, Modulefulness.none) :: Param(FldFlags.empty, idx, N, Modulefulness.none) :: Nil) :: newParams,
if isCtor then Begin(blk, Return(idx.asSimpleRef)) else blk)(configOverride = f.configOverride, annotations = f.annotations)
val fakeCtor = transformFunDefn(FunDefn.withFreshSymbol(
Expand Down
6 changes: 3 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/DeadParamElim.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise):
val name = instId.mkFunName + s"$$${f.nme}"
f -> (
new BlockMemberSymbol(name, Nil, true),
new TermSymbol(Fun, N, Tree.Ident(name)))
new TermSymbol(Fun, N, Tree.Ident(name), erasedType = N))
.toMap)
end mkNewPolyFnSyms

Expand Down Expand Up @@ -335,12 +335,12 @@ class Rewrite(val deadParamElimSolver: DeadParamElimSolver)(using Raise):
case ParamList(flags, params, restParam) =>
val params2 = params.map:
case p =>
val newSym = new VarSymbol(Tree.Ident(p.sym.name))
val newSym = new VarSymbol(Tree.Ident(p.sym.name), erasedType = N)
refreshParamMap(p.sym) = newSym
Param(p.flags, newSym, p.sign, p.modulefulness)
val rest2 = restParam.map:
case p =>
val newSym = new VarSymbol(Tree.Ident(p.sym.name))
val newSym = new VarSymbol(Tree.Ident(p.sym.name), erasedType = N)
refreshParamMap(p.sym) = newSym
Param(p.flags, newSym, p.sign, p.modulefulness)
ParamList(flags, params2, rest2)
Expand Down
4 changes: 2 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/EtaExpansion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class EtaExpansionRewrite(val etaExpansionSolver: EtaExpansionSolver)(using Rais
targetShape.drop(existingShape.size).zipWithIndex.map:
case (count, idx) =>
val params = (0 until count).toList.map: i =>
Param.simple(new VarSymbol(new Tree.Ident(s"eta$$$idx$$$i")))
Param.simple(new VarSymbol(new Tree.Ident(s"eta$$$idx$$$i"), erasedType = N))
EtaParamList(
ParamList(ParamListFlags.empty, params, N),
params.map(p => Arg(N, p.sym.asSimpleRef)),
Expand All @@ -195,7 +195,7 @@ class EtaExpansionRewrite(val etaExpansionSolver: EtaExpansionSolver)(using Rais
Return(
Call(fun, (argss ++ activeEtaArgss).ne_!)(c.metadata))
case _ =>
val tmp = TempSymbol(N, "eta$res")
val tmp = TempSymbol(N, erasedType = N, "eta$res")
Scoped(
Set.single(tmp),
Assign(tmp, res2, Return(etaCall(tmp.asPath).withLocOf(res2))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class FirstClassFunctionTransformer

private def etaExpandPath(p: Path, params: ParamList)(k: Path => Block): Block =
val clsDef = generateFCFunctionClass(p, params)
val tmp = new TempSymbol(None)
val tmp = new TempSymbol(None, erasedType = S(ErasedType.AnyRef(rsc = false, clsDef.isym.asClsOrMod.get)))
val cls = clsDef.sym.asMemberRef(clsDef.isym)
Scoped(Set(clsDef.sym, tmp), Define(clsDef, Assign(tmp, Instantiate(false, cls, Nil :: Nil)(InstantiateMetadata.empty), k(tmp.asSimpleRef))))

Expand Down
Loading
Loading