TBH, I don't understand the JAX code; it must rely on some library I don't know. stateless_call, attach_weights, etc. don't seem very functional to me. In general, my implementation would look simpler: just pass the weights in from outside and take derivatives.
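Roughly what I have in mind, as a minimal sketch in plain JAX. The names predict and loss and the toy linear model are made up for illustration; this is not the code under discussion.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # The weights come in as an argument; nothing is stored on a model object.
    w, b = params
    return x @ w + b

def loss(params, x, y):
    # Mean squared error of the toy linear model (hypothetical example).
    return jnp.mean((predict(params, x) - y) ** 2)

# Toy data and zero-initialized weights, chosen only for illustration.
params = (jnp.zeros((3,)), jnp.array(0.0))
x = jnp.ones((4, 3))
y = jnp.ones((4,))

# jax.grad differentiates with respect to the first argument (params)
# and returns gradients with the same tuple structure.
grads = jax.grad(loss)(params, x, y)
```

Since params is an ordinary argument, the gradient comes back as a (w, b) tuple and nothing needs to be attached to or detached from a model.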
-
-
I'm confused about attach_weights. Is this (non-functional) call really needed if we are passing the model state around anyway?
-
Yep, my point is that one can just pass around some composite state, as in a monad. In Linen, that state is essentially the "optimizer", but it has a state component that is the actual model state.
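A rough sketch of that state-passing idea in plain JAX. TrainState, train_step and the momentum update are hypothetical names and choices for illustration, not Linen's actual API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class TrainState(NamedTuple):
    # Hypothetical composite state: model weights plus optimizer bookkeeping.
    params: jnp.ndarray
    momentum: jnp.ndarray
    step: int

def loss(params, x, y):
    # Mean squared error of a toy linear model (illustrative only).
    return jnp.mean((x @ params - y) ** 2)

def train_step(state, x, y, lr=0.1, beta=0.9):
    # Pure function: takes the old state, returns a new one, mutates nothing.
    grads = jax.grad(loss)(state.params, x, y)
    momentum = beta * state.momentum + grads
    params = state.params - lr * momentum
    return TrainState(params, momentum, state.step + 1)

# Toy data, chosen only for illustration.
state0 = TrainState(jnp.zeros(2), jnp.zeros(2), 0)
x = jnp.eye(2)
y = jnp.array([1.0, -1.0])

state1 = train_step(state0, x, y)
```

Every step is just state -> state, so the model weights ride along inside the composite state and state0 is left untouched after the call.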