TBH, I don't understand the JAX code; it must rely on some library I don't know. stateless_call, attach_weights, etc. don't seem very functional to me. In general, my implementation would look simpler: just passing in the weights from outside and taking derivatives.
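For contrast, here's roughly what I mean, as a minimal sketch (the toy model and names are mine, not from the code being discussed): the weights are an ordinary pytree passed in from outside, and jax.grad differentiates with respect to them.

import jax
import jax.numpy as jnp

# Hypothetical toy model: the weights are just a dict passed in from outside.
def predict(params, x):
    return jnp.dot(x, params['w']) + params['b']

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

params = {'w': jnp.zeros((3,)), 'b': jnp.array(0.0)}
x, y = jnp.ones((5, 3)), jnp.zeros((5,))

# jax.grad differentiates with respect to the first argument (the weights),
# so there is no hidden state: weights in, gradients out.
grads = jax.grad(loss)(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)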
Maybe there are better solutions, I don't know. I'm still learning JAX.
The Linen version hides everything in the "optimizer", even the weights, and gives back a new "optimizer". One might say it's a misnomer. One can then replace various parts of the optimizer and create new optimizers.
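If I understand the pattern right, it's something like this minimal sketch (a made-up reimplementation of the idea, not the actual Flax API): the optimizer object owns the weights, and applying gradients returns a whole new optimizer instead of mutating anything.

import jax

# Made-up minimal version of the pattern: the "optimizer" holds the weights
# (target), and apply_gradient returns a brand-new optimizer.
class Optimizer:
    def __init__(self, target, lr=0.1):
        self.target = target  # the weights live inside the optimizer
        self.lr = lr

    def apply_gradient(self, grads):
        new_target = jax.tree_util.tree_map(
            lambda p, g: p - self.lr * g, self.target, grads)
        return Optimizer(new_target, self.lr)  # new object, nothing mutated

# Usage: the weights ride along inside the optimizer.
#   opt = Optimizer(params)
#   opt = opt.apply_gradient(grads)
#   opt.target  # the updated weights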
New conversation
OK, I see what you mean by non-trainable weights. I think Linen just hides their updates in the optimizer. I need to check.
It's not 'hidden' in Linen, but I don't mind the extra bit of verbosity. Haiku and Linen are similar in their handling of this, see https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py#L148-L190 and https://github.com/google/flax/blob/main/examples/imagenet/train.py#L113-L165
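For concreteness, here's a small Linen sketch of what those examples do with batch stats (toy model, my own names): the non-trainable state isn't in the optimizer at all; apply is called with mutable=['batch_stats'] and the updated stats come back as an explicit return value that the training loop threads through itself.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Toy model with a BatchNorm layer, so it carries non-trainable batch stats.
class Model(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(16)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.Dense(1)(x)

model = Model()
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x, train=True)
params, batch_stats = variables['params'], variables['batch_stats']

# The updated stats are handed back explicitly, not hidden in the optimizer.
out, updated = model.apply(
    {'params': params, 'batch_stats': batch_stats},
    x, train=True, mutable=['batch_stats'])
batch_stats = updated['batch_stats']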