Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 58 additions & 42 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/ups/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,40 +24,45 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter
import Pattern.*

/** A previously-computed matcher result for one field of the current
* multi-matcher. The runtime representation is shape-dependent:
* singleton-label matchers return the label's value directly, while
* multi-label matchers return a record keyed by label field names.
* multi-matcher. In full mode the value also carries the original field
* input, which is needed when a successful field pattern preserves its
* scrutinee. Match-only mode never consumes that input, so it stores the
* submatcher result directly.
*/
private final case class MatcherResult(symbol: VarSymbol, labels: Set[Label]):
private def result: Term = sel(symbol.safeRef, "result")
def input: Term = sel(symbol.safeRef, "input")
private def result(using ResultMode): Term =
if isMatchOnly then symbol.safeRef else sel(symbol.safeRef, "result")
def input(using ResultMode): Term =
softAssert(!isMatchOnly,
"Match-only field matcher results should not expose their input.")
sel(symbol.safeRef, "input")
/** Read the result for one label from this matcher result, abstracting over
* the singleton direct-return optimization.
*/
def select(label: Label): Term =
def select(label: Label)(using ResultMode): Term =
if labels.size is 1 then result
else sel(result, label.asFieldName)
/** Produce the default failure value for this matcher result with the same
* shape that a successful submatcher call would have produced.
*/
def default(using ResultMode): Term =
val result = labels.toList match
case label :: Nil => emptyMatchResult("empty")
case labels =>
Rcd(false, labels.map: label =>
RcdField(str(label.asFieldName), emptyMatchResult("empty")))
rcd(
RcdField(str("input"), `null`),
RcdField(str("result"), result)
)

def default(using ResultMode): Term = matcherResult(`null`, labels.toList match
case label :: Nil => emptyMatchResult("empty")
case labels =>
Rcd(false, labels.map: label =>
RcdField(str(label.asFieldName), emptyMatchResult("empty"))))

/** Make a match result record containing `input` and `result` fields. */
private def matcherResult(input: => Term, result: => Term)(using ResultMode): Term =
if isMatchOnly then result
else rcd(RcdField(str("input"), input), RcdField(str("result"), result))
private def bool(value: Bool): Term = Term.Lit(BoolLit(value))

private def isMatchOnly(using mode: ResultMode): Bool = mode is ResultMode.MatchOnly

private def emptyMatchResult(reason: Str)(using mode: ResultMode): Term =
if isMatchOnly then bool(false) else makeMatchFailure(str(reason))

private def nullifyEmptyBindings(bindings: Term): Term = bindings match
case Rcd(false, Nil) => `null`
case bindings => bindings
Expand All @@ -81,11 +86,11 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter

extension (head: Head)
/** Create a flat pattern that can be used in the UCS expressions. */
def toFlatPattern: FlatPattern = head match
def toFlatPattern(arguments: Opt[Ls[(LocalVarSymbol, Opt[Loc])]]): FlatPattern = head match
case lit: syntax.Literal => FlatPattern.Lit(lit)
case sym: (ClassSymbol | ModuleOrObjectSymbol) =>
val constructor = reference(sym, head.toLoc).getOrElse(Term.Error().withLocOf(head))
FlatPattern.ClassLike(constructor, sym, N, false)(Tree.Dummy)
FlatPattern.ClassLike(constructor, sym, arguments, false)(Tree.Dummy)
def showDbg: Str = head match
case lit: syntax.Literal => lit.idStr
case sym: ClassLikeSymbol => sym.nme
Expand Down Expand Up @@ -150,15 +155,25 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter
val scrutinee = VarSymbol(Ident("input"))
// Assemble branches for constructors and literals.
val branches = heads.map: head =>
// Weird. Removing type annotations caused type errors.
val specialized = expandedPatterns.specializeSet(S(head))
val consequent = Split.Else(multiMatcherBranch(specialized, scrutinee))
Branch(scrutinee.safeRef, head.toFlatPattern, consequent)
lazy val empty = (N: Opt[Ls[(LocalVarSymbol, Option[Loc])]], Map.empty[Ident | Int, LocalVarSymbol])
val (classFieldArguments, classFields) = head match
case symbol: ClassSymbol => symbol.defn.getOrElse(lastWords(s"Missing definition for symbol `${symbol.nme}`.")).paramsOpt match
case N => empty
case S(params) =>
val empty: (Ls[(LocalVarSymbol, Opt[Loc])], Map[Ident | Int, LocalVarSymbol]) = (Nil, Map.empty)
val (arguments, fields) = params.params.foldLeft(empty): (acc, param) =>
val (argAcc, fieldAcc) = acc
val fieldSymbol = TempSymbol(N, param.sym.nme)
(argAcc :+ (fieldSymbol, param.toLoc), fieldAcc + ((param.sym.id: Ident | Int) -> fieldSymbol))
(S(arguments), fields)
case _: (syntax.Literal | ModuleOrObjectSymbol) => empty
val consequent = Split.Else(multiMatcherBranch(specialized, scrutinee, classFields))
Branch(scrutinee.safeRef, head.toFlatPattern(classFieldArguments), consequent)
// Assemble the default branch.
val default =
// Weird. Removing type annotations caused type errors.
Comment thread
chengluyu marked this conversation as resolved.
val specialized = expandedPatterns.specializeSet(N)
Split.Else(multiMatcherBranch(specialized, scrutinee))
Split.Else(multiMatcherBranch(specialized, scrutinee, Map.empty))
// Make a split that tries all branches in order.
val topmostSplit = branches.foldRight(default)(_ ~: _)
val bodyTerm = SynthIf(topmostSplit)
Expand All @@ -167,7 +182,8 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter

def multiMatcherBranch(
patterns: Set[(Label, SpPat)],
scrutinee: LocalVarSymbol
scrutinee: LocalVarSymbol,
knownFields: Map[Ident | Int, LocalVarSymbol]
)(using ResultMode): Blk = trace(
pre = s"multiMatcherBranch: scrutinee = ${scrutinee} | patterns = ${
patterns.iterator.map: (label, pattern) =>
Expand All @@ -186,20 +202,20 @@ class Compiler(using Context)(using tl: TL)(using Ctx, State, Raise) extends Ter
log(s"subPattern for field ${field.showDbg}: ${
subPatterns.iterator.map(_.showDbg).mkString("{", ", ", "}")}")
val subMatcherSymbol = buildMultiMatcher(subPatterns)
val conditional =
// Check the presence of the field, and call the matcher if it exists.
val fieldIdent: Ident = field.asIdent
val fieldSymbol = TempSymbol(N, fieldIdent.name)
val fieldTest = FlatPattern.Record((fieldIdent -> fieldSymbol) :: Nil)
val consequent = Split.Else:
val resultTerm = app(subMatcherSymbol.safeRef, tup(fld(fieldSymbol.safeRef)), "result")
rcd(
RcdField(str("input"), fieldSymbol.safeRef),
RcdField(str("result"), resultTerm)
)
val branch = Branch(scrutinee.safeRef, fieldTest, consequent)
SynthIf(branch ~: Split.Else(subScrutinee.default))
LetDecl(subScrutinee.symbol, Nil) :: DefineVar(subScrutinee.symbol, conditional) :: Nil
val makeResult = (fieldSymbol: LocalVarSymbol) => matcherResult(
Comment thread
chengluyu marked this conversation as resolved.
Outdated
fieldSymbol.safeRef,
app(subMatcherSymbol.safeRef, tup(fld(fieldSymbol.safeRef)), "result"))
val result = knownFields.get(field) match
case S(fieldSymbol) => makeResult(fieldSymbol)
case N =>
// Check the presence of the field, and call the matcher if it exists.
val fieldIdent: Ident = field.asIdent
val fieldSymbol = TempSymbol(N, fieldIdent.name)
val fieldTest = FlatPattern.Record((fieldIdent -> fieldSymbol) :: Nil)
val consequent = Split.Else(makeResult(fieldSymbol))
val branch = Branch(scrutinee.safeRef, fieldTest, consequent)
SynthIf(branch ~: Split.Else(subScrutinee.default))
LetDecl(subScrutinee.symbol, Nil) :: DefineVar(subScrutinee.symbol, result) :: Nil
.toList
// For each pattern, we compile a split and bind the result to a variable.
// The variable will be a field of the output record.
Expand Down
198 changes: 198 additions & 0 deletions hkmc2/shared/src/test/mlscript/ucs/patterns/CompiledClassPatterns.mls
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
:js

open annotations

data class Add(val lhs, val rhs)

:sir
case Add(0, 0) then 1
//│ ———————————————| Lowered IR |———————————————————————————————————————————————————————————————————————
//│ let lambda;
//│ @private
//│ define lambda as fun lambda⁰(caseScrut) {
//│ let arg$Add$0$, arg$Add$1$;
//│ match caseScrut
//│ Add⁰ =>
//│ set arg$Add$0$ = caseScrut.lhs⁰;
//│ set arg$Add$1$ = caseScrut.rhs⁰;
//│ match arg$Add$0$
//│ 0 =>
//│ match arg$Add$1$
//│ 0 =>
//│ return 1
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ };
//│ return lambda⁰
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ = fun

:sir
case @compile Add(0, 0) then 1
//│ ———————————————| Lowered IR |———————————————————————————————————————————————————————————————————————
//│ let lambda;
//│ @private
//│ define lambda as fun lambda³(caseScrut) {
//│ let matcher__Addʿ0ꓹ0ʾ$, matcher__0$, matchSuccess, lambda1, lambda2;
//│ define lambda1 as fun lambda¹(input) {
//│ let p_1$, p_1$1, tmp, tmp1;
//│ match input
//│ 0 =>
//│ set tmp = true;
//│ set p_1$ = tmp;
//│ return p_1$
//│ else
//│ set tmp1 = false;
//│ set p_1$1 = tmp1;
//│ return p_1$1
//│ end
//│ };
//│ set matcher__0$ = lambda¹;
//│ define lambda2 as fun lambda²(input) {
//│ let lhs, rhs, lhs1, rhs1, p_0$, result1$, result1$1, p_0$1, tmp, tmp1;
//│ match input
//│ Add⁰ =>
//│ set lhs = input.lhs⁰;
//│ set rhs = input.rhs⁰;
//│ set lhs1 = matcher__0$(lhs);
//│ set rhs1 = matcher__0$(rhs);
//│ set result1$1 = lhs1;
//│ match result1$1
//│ true =>
//│ set result1$ = rhs1;
//│ match result1$
//│ true =>
//│ set tmp = true;
//│ end
//│ else
//│ set tmp = false;
//│ end
//│ end
//│ else
//│ set tmp = false;
//│ end
//│ set p_0$ = tmp;
//│ return p_0$
//│ else
//│ set tmp1 = false;
//│ set p_0$1 = tmp1;
//│ return p_0$1
//│ end
//│ };
//│ set matcher__Addʿ0ꓹ0ʾ$ = lambda²;
//│ set matchSuccess = matcher__Addʿ0ꓹ0ʾ$(caseScrut);
//│ match matchSuccess
//│ true =>
//│ return 1
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ };
//│ return lambda³
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ = fun

data class Box(val value)

:soir
Comment thread
chengluyu marked this conversation as resolved.
case @compile (Box(0) | Box(1)) then 1
//│ ——————————————| Optimized IR |——————————————————————————————————————————————————————————————————————
//│ let lambda;
//│ @private
//│ define lambda as fun lambda⁴(caseScrut) {
//│ let value, result1$, result2$;
//│ match caseScrut
//│ Box⁰ =>
//│ let inlinedVal;
//│ set value = caseScrut.value⁰;
//│ match value
//│ 0 =>
//│ set inlinedVal = { "p_1": true, "p_2": false };
//│ end
//│ 1 =>
//│ set inlinedVal = { "p_1": false, "p_2": true };
//│ end
//│ else
//│ set inlinedVal = { "p_1": false, "p_2": false };
//│ end
//│ set result1$ = inlinedVal.p_1﹖;
//│ match result1$
//│ true =>
//│ return 1
//│ else
//│ set result2$ = inlinedVal.p_2﹖;
//│ match result2$
//│ true =>
//│ return 1
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ end
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ };
//│ return lambda⁴
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ = fun

:soir
Comment thread
chengluyu marked this conversation as resolved.
Outdated
case @compile (Box(0) & Box(0 | 1)) then 1
//│ ——————————————| Optimized IR |——————————————————————————————————————————————————————————————————————
//│ let lambda;
//│ @private
//│ define lambda as fun lambda⁵(caseScrut) {
//│ let value, result1$, result2$;
//│ match caseScrut
//│ Box⁰ =>
//│ let inlinedVal;
//│ set value = caseScrut.value⁰;
//│ match value
//│ 0 =>
//│ set inlinedVal = { "p_1": true, "p_2": true };
//│ end
//│ 1 =>
//│ set inlinedVal = { "p_1": false, "p_2": true };
//│ end
//│ else
//│ set inlinedVal = { "p_1": false, "p_2": false };
//│ end
//│ set result1$ = inlinedVal.p_1﹖;
//│ match result1$
//│ true =>
//│ set result2$ = inlinedVal.p_2﹖;
//│ match result2$
//│ true =>
//│ return 1
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ else
//│ throw new globalThis⁰.Error⁰("match error")
//│ end
//│ };
//│ return lambda⁵
//│ —————————————————| Output |—————————————————————————————————————————————————————————————————————————
//│ = fun

class Add2(val x, val y, z)

:e
case @compile (Add2(0, _, _) | Add2(_, 1, _)) then 0
//│ ╔══[COMPILATION ERROR] Parameter `z` is not accessible.
//│ ║ l.188: class Add2(val x, val y, z)
//│ ╙── ^
//│ ╔══[COMPILATION ERROR] Parameter `z` is not accessible.
//│ ║ l.188: class Add2(val x, val y, z)
//│ ╙── ^
//│ = fun
2 changes: 1 addition & 1 deletion hkmc2/shared/src/test/mlscript/ups/SimpleTransform.mls
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ Infix("**", Literal(1), Literal(2)) is EvaluatedTerm
fun evaluate(term) = if term is
(@compile EvaluatedTerm) as v then Some(v)
else None
//│ Lines of IR: 3624
//│ Lines of IR: 3136


evaluate of Literal(1)
Expand Down
Loading
Loading