TBH, I don't understand the JAX code. It must rely on some library I don't know; stateless_call, attach_weights, etc. don't seem very functional to me. In general, my implementation would look simpler: just pass in the weights from outside and take derivatives.
Maybe there are better solutions; I don't know. I'm still learning JAX.
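A minimal sketch of the "pass the weights in from outside" style, in plain JAX with no Flax or other library on top: the model and loss are ordinary functions of a params dict, jax.grad differentiates with respect to that dict, and each step returns new params rather than mutating anything. The linear model, the toy data, and the names predict, loss and sgd_step are illustrative assumptions, not the code under discussion.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Plain function: the weights arrive as an argument, nothing is hidden in an object.
    return x @ params["w"] + params["b"]


def loss(params, x, y):
    # Mean squared error of the linear model.
    return jnp.mean((predict(params, x) - y) ** 2)


# jax.grad differentiates w.r.t. the first argument, i.e. the whole params pytree.
grad_fn = jax.jit(jax.grad(loss))


def sgd_step(params, x, y, lr=0.1):
    grads = grad_fn(params, x, y)
    # Build and return new params; the old dict is left untouched.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Toy data for y = 2x + 1 (hypothetical example data).
x = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * x[:, 0] + 1.0

params = {"w": jnp.zeros((1,)), "b": jnp.array(0.0)}
for _ in range(500):
    params = sgd_step(params, x, y)
```

Because the params are just a pytree passed in and returned, swapping the optimizer or the model only means swapping a function.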
The Linen version hides everything in the "optimizer", even the weights, and gives back a new "optimizer" after each update. One might say the name is a misnomer. The upside is that one can then replace various parts of the optimizer and create new optimizers.
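For contrast, a rough pure-JAX imitation of the optimizer-as-container pattern described above: a hypothetical Optimizer NamedTuple holds the weights (target), an update rule and that rule's state, and apply_gradient returns a new Optimizer instead of mutating the old one. Swapping sgd for momentum illustrates "replace parts, get a new optimizer". Every name here is made up for illustration; it is loosely modelled on the pattern, not Flax's actual implementation.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class Optimizer(NamedTuple):
    """Hypothetical container bundling the weights with the update rule."""
    target: Any            # the model weights live inside the "optimizer"
    state: Any             # per-rule state (e.g. momentum buffers), or None
    update_rule: Callable  # (params, grads, state) -> (new_params, new_state)

    def apply_gradient(self, grads):
        new_target, new_state = self.update_rule(self.target, grads, self.state)
        # Return a fresh Optimizer; the old one is left unchanged.
        return self._replace(target=new_target, state=new_state)


def sgd(lr):
    def rule(params, grads, state):
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return new_params, state
    return rule


def momentum(lr, beta=0.9):
    def rule(params, grads, velocity):
        velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
        new_params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
        return new_params, velocity
    return rule


def loss(params):
    # Toy objective with its minimum at w = 3.
    return jnp.sum((params["w"] - 3.0) ** 2)


params = {"w": jnp.zeros(2)}

opt_sgd = Optimizer(params, None, sgd(0.1))
for _ in range(100):
    opt_sgd = opt_sgd.apply_gradient(jax.grad(loss)(opt_sgd.target))

# Same weights, different update rule: only one part of the container changes.
zero_velocity = jax.tree_util.tree_map(jnp.zeros_like, params)
opt_mom = Optimizer(params, zero_velocity, momentum(0.05))
for _ in range(300):
    opt_mom = opt_mom.apply_gradient(jax.grad(loss)(opt_mom.target))
```

The trade-off is visible here: the container makes swapping update rules easy, but the weights are reached through opt.target rather than being passed around directly.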