TBH, I don't understand the JAX code; it must rely on some library I don't know. stateless_call, attach_weights, etc. don't seem very functional to me. In general, my implementation would look simpler: just pass in the weights from outside and take derivatives.
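For illustration, here is a minimal sketch of what "pass in the weights from outside and take derivatives" could look like in plain JAX. The names `predict`, `loss_fn`, and `params`, the linear model, and the toy data are all assumptions made up for this example, not code from the thread.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # The model is a pure function: the weights arrive as an argument.
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    # Mean squared error, used here only as a stand-in loss.
    return jnp.mean((predict(params, x) - y) ** 2)

# Hypothetical weights and toy data.
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((4, 3))
y = jnp.ones((4, 1))

# jax.grad differentiates with respect to the first argument (params)
# and returns a pytree with the same structure as params.
grads = jax.grad(loss_fn)(params, x, y)
```

Nothing is stored on the model itself, so the gradient call only needs the function and the weights.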
-
-
Managing the weights from scratch is non-trivial, but in the framework I use (Linen), the initializer collects all the weight matrices into a single pytree, and I then just pass them to the model (which is essentially a function) when I compute the gradient.
-
If I don't have any non-trainable weights, I get this, which is much simpler. However, adding non-trainable weights into the picture adds significant complication. The same applies to metric state updates as well. pic.twitter.com/oQ8zbfoq3W
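To make the complication concrete, here is one possible way to thread non-trainable state through a pure function in plain JAX. This is a sketch under assumptions: the running-mean state, the 0.9 decay factor, and every name below are invented for the example and are not what the screenshot shows.

```python
import jax
import jax.numpy as jnp

def forward(params, state, x):
    # Non-trainable state: a running mean that is updated, not differentiated.
    batch_mean = jnp.mean(x, axis=0)
    new_state = {"mean": 0.9 * state["mean"] + 0.1 * batch_mean}
    out = (x - batch_mean) @ params["w"]
    return out, new_state

def loss_fn(params, state, x, y):
    preds, new_state = forward(params, state, x)
    loss = jnp.mean((preds - y) ** 2)
    # The updated state is returned as auxiliary output next to the loss.
    return loss, new_state

params = {"w": jnp.ones((3, 1))}
state = {"mean": jnp.zeros((3,))}
x = jnp.arange(12.0).reshape(4, 3)
y = jnp.zeros((4, 1))

# has_aux=True tells JAX that only the first returned value is the loss;
# gradients are taken with respect to params only.
(loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, state, x, y)
```

The extra bookkeeping is that every call has to take the state in and hand the new state back out. Metric accumulators could presumably be carried the same way.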