Skip to content

Integrate with Mooncake#278

Open
mofeing wants to merge 3 commits into
masterfrom
feature/mooncake-support
Open

Integrate with Mooncake#278
mofeing wants to merge 3 commits into
masterfrom
feature/mooncake-support

Conversation

@mofeing

@mofeing mofeing commented Jan 8, 2025

Copy link
Copy Markdown
Collaborator

CC @willtebbutt

i'm having also a problem when trying to compute the gradient.

this code:

A = Tensor(rand(2,3), (:i,:j))
B = Tensor(rand(3,4), (:j,:k))

f(a,b) = sum(contract(a,b))

rule = build_rrule(Tuple{typeof(f), typeof(A), typeof(B)})
Mooncake.value_and_gradient!!(rule, contract, A, B)

gives this error:

ERROR: ArgumentError: signature of arguments, Tuple{Mooncake.CoDual{typeof(contract), NoFData}, Mooncake.CoDual{Tensor{Float64, 2, Matrix{Float64}}, Mooncake.FData{@NamedTuple{data::Matrix{Float64}, inds::Vector{NoTangent}}}}, Mooncake.CoDual{Tensor{Float64, 2, Matrix{Float64}}, Mooncake.FData{@NamedTuple{data::Matrix{Float64}, inds::Vector{NoTangent}}}}}, not equal to signature required by rule, Tuple{Mooncake.CoDual{typeof(f), NoFData}, Mooncake.CoDual{Tensor{Float64, 2, Matrix{Float64}}, Mooncake.FData{@NamedTuple{data::Matrix{Float64}, inds::Vector{NoTangent}}}}, Mooncake.CoDual{Tensor{Float64, 2, Matrix{Float64}}, Mooncake.FData{@NamedTuple{data::Matrix{Float64}, inds::Vector{NoTangent}}}}}.
Stacktrace:
 [1] __verify_sig(rule::Mooncake.DerivedRule{…}, fx::Tuple{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:27
 [2] __value_and_gradient!!(::Mooncake.DerivedRule{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:81
 [3] value_and_gradient!!(::Mooncake.DerivedRule{…}, ::Function, ::Tensor{…}, ::Tensor{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:145
 [4] top-level scope
   @ REPL[25]:1
Some type information was truncated. Use `show(err)` to see complete types.

@willtebbutt

Copy link
Copy Markdown

gives this error:

Is the problem just that f should be passed to value_and_gradient!!, rather than contract?

@mofeing

mofeing commented Jan 8, 2025

Copy link
Copy Markdown
Collaborator Author

ah right 🤦

well, that uncovers another problem

julia> Mooncake.value_and_gradient!!(rule, f, A, B)
ERROR: MethodError: no method matching (::TenetChainRulesCoreExt.var"#contract_pullback#61"{…})(::ChainRulesCore.Tangent{…})
The function `contract_pullback` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  (::TenetChainRulesCoreExt.var"#contract_pullback#61")(::ChainRulesCore.AbstractThunk)
   @ TenetChainRulesCoreExt ~/Developer/Tenet.jl/ext/TenetChainRulesCoreExt/rrules.jl:117
  (::TenetChainRulesCoreExt.var"#contract_pullback#61")(::Tensor)
   @ TenetChainRulesCoreExt ~/Developer/Tenet.jl/ext/TenetChainRulesCoreExt/rrules.jl:110
  (::TenetChainRulesCoreExt.var"#contract_pullback#61")(::AbstractVector)
   @ TenetChainRulesCoreExt ~/Developer/Tenet.jl/ext/TenetChainRulesCoreExt/rrules.jl:116
  ...

Stacktrace:
 [1] (::Mooncake.var"#pb!!#291"{…})(y_rdata::NoRData)
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/tools_for_rules.jl:295
 [2] (::Mooncake.RRuleWrapperPb{Mooncake.var"#pb!!#291"{…}, Mooncake.LazyZeroRData{…}})(dy::NoRData)
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interpreter/s2s_reverse_mode_ad.jl:299
 [3] __run_rvs_pass!(::Type, ::Type{…}, ::Mooncake.RRuleWrapperPb{…}, ::Base.RefValue{…}, ::Nothing, ::Vararg{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interpreter/s2s_reverse_mode_ad.jl:820
 [4] f
   @ ./REPL[42]:1 [inlined]
 [5] (::Tuple{Mooncake.Stack{…}, Base.RefValue{…}, Mooncake.RRuleZeroWrapper{…}, Mooncake.Stack{…}})(none::Any)
   @ Base.Experimental ./<missing>:0
 [6] Pullback
   @ ~/.julia/packages/Mooncake/oBjQd/src/interpreter/s2s_reverse_mode_ad.jl:855 [inlined]
 [7] __value_and_gradient!!(::Mooncake.DerivedRule{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:85
 [8] value_and_gradient!!(::Mooncake.DerivedRule{…}, ::Function, ::Tensor{…}, ::Tensor{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:145
 [9] top-level scope
   @ REPL[44]:1
Some type information was truncated. Use `show(err)` to see complete types.

i probably need to add a method to contract_pullback in here

function contract_pullback(c̄::Tensor)
= @thunk proj_a(contract(c̄, conj(b); out=inds(a)))
= @thunk proj_b(contract(conj(a), c̄; out=inds(b)))
return (NoTangent(), ā, b̄)
end
contract_pullback(c̄::AbstractArray) = contract_pullback(Tensor(c̄, inds(c)))
contract_pullback(c̄::AbstractVector) = contract_pullback(Tensor(c̄, inds(c)))
contract_pullback(c̄::AbstractThunk) = contract_pullback(unthunk(c̄))
but i'm not sure what type of object is Mooncake passing to the ChainRules callbacks because it's an error i haven't found before with Zygote and ChainRulesTestUtils

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants