Birchlabs
@Birchlabs
Very Full-stack Software Developer at Some FinTech Company. Studied Japanese at university. part-time fumo researcher.
birchlabs.co.uk · Joined August 2008

Birchlabs’s Tweets

matmul is just lots of dot products. matmul(x, y) is "every row of x dot-product its corresponding column in y". you can express: matmul(x, y) as lots of dot products: ((x.unsqueeze(-2).expand(-1, y.size(-1), -1)) * y.T).sum(-1) momentarily uses x.size(-1) times more memory tho
Python REPL demonstrating that the two formulations of x @ y return the same result.

from torch import tensor, matmul
x = tensor([[1,  2],
            [3, -1],
            [-2, 3]])
y = tensor([[1, -1, 2, 3],
            [2, 1, -1, 2]])
matmul(x, y)  # built-in matmul: (3, 2) @ (2, 4) -> (3, 4)
((x.unsqueeze(-2).expand(-1, y.size(-1), -1)) * y.T).sum(-1)  # same result, as a batch of dot products
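
A quick check (not part of the original REPL) that the two formulations agree, using torch.equal:

import torch
from torch import tensor, matmul

x = tensor([[ 1,  2],
            [ 3, -1],
            [-2,  3]])
y = tensor([[ 1, -1,  2,  3],
            [ 2,  1, -1,  2]])

via_matmul = matmul(x, y)  # (3, 2) @ (2, 4) -> (3, 4)
via_dots = ((x.unsqueeze(-2).expand(-1, y.size(-1), -1)) * y.T).sum(-1)
assert torch.equal(via_matmul, via_dots)
print(via_matmul)
# tensor([[ 5,  1,  0,  7],
#         [ 1, -4,  7,  7],
#         [ 4,  5, -7,  0]])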
Measured torch built-in flash attention:
Quote Tweet
tried out pytorch's built-in Flash Attention on 4090.
no flash attention: 4.10it/s
torch.nn.MultiheadAttention: 6.20~6.23it/s
xformers flash attention: 6.32~6.35it/s
batch-of-5 images, 528x768 res, 22 steps, 6 batches
github.com/Birch-san/diff
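
For context, a minimal sketch of the fused attention path being benchmarked, assuming PyTorch 2.0's scaled_dot_product_attention; shapes and dtypes here are illustrative, not taken from the benchmark:

import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim); sizes are made up for illustration
q = torch.randn(5, 8, 4096, 64, device='cuda', dtype=torch.float16)
k = torch.randn(5, 8, 4096, 64, device='cuda', dtype=torch.float16)
v = torch.randn(5, 8, 4096, 64, device='cuda', dtype=torch.float16)

# naive attention: materializes the full (seq, seq) attention matrix
scale = q.size(-1) ** -0.5
naive = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1) @ v

# PyTorch 2.0 fused attention; dispatches to a flash kernel when one is available
fused = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(naive, fused, atol=1e-2)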
ERA-Solver seems to absolutely slap. Its FID at 10 steps is lower than what the other samplers achieved at 100 steps. Note: their FID comparison covers only the *older* DPM-Solvers; DPM-Solver++ appears only in the computation-time comparison. DEIS isn't mentioned at all.
Quote Tweet
ERA-Solver: Error-Robust Adams Solver for Fast Sampling of Diffusion Probabilistic Models abs: arxiv.org/abs/2301.12935
I should clarify: I don't mean that attention got 50% faster. I mean that *image generation* got 50% faster; attention is one part of that. On Mac I measured that attention accounts for just under half of the Unet's execution time for 512x512 images.
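
A sketch of how such a breakdown could be measured with torch.profiler (my own illustration, not the actual measurement script):

import torch
from torch.profiler import profile, record_function, ProfilerActivity

def attention(q, k, v):
    scale = q.size(-1) ** -0.5
    return (q @ k.transpose(-2, -1) * scale).softmax(dim=-1) @ v

q = k = v = torch.randn(8, 4096, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("attention"):
        attention(q, k, v)

# compare the "attention" row against total time
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))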
I also wanna check out pytorch's built-in flash attention and see how it compares to xformers. compiled pytorch 2.0 alpha with USE_FLASH_ATTENTION=ON; will give it a spin soon
pytorch build flags indicating USE_FLASH_ATTENTION: ON
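
Assuming the PyTorch 2.0 torch.backends.cuda API, a quick way to check that the flash kernel is actually available in that build:

import torch
import torch.nn.functional as F

# reports whether the flash kernel is enabled for scaled_dot_product_attention
print(torch.backends.cuda.flash_sdp_enabled())

# restrict dispatch to the flash kernel only; this errors if the build lacks it
q = k = v = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)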
I never found the pen, so I did it with numpy and a footswitch
Quote Tweet
it actually worked! you can split floats into integer mantissae and exponents, and compute matmul via elementwise products. exponents can be multiplied cheaply (it's just addition!) new datatype for ML training? gist.github.com/Birch-san/4f49
import numpy as np

# two formulations which achieve the same result:

# regular matmul
mat = np.array([[1.3e2, 8.e2], [1.6e1, 5.e-1]])
vec = np.array([2.1e2, 3.6e-4])
mat @ vec
# -> approximately [27300.288, 3360.00018]

# matmul via separate elementwise products of mantissa and exponent,
# then recombining when we want to return to canonical representation
mat_m = np.array([[13, 8], [16, 5]])
mat_e = np.array([[1, 2], [0, -1]])
vec_m = np.array([21, 36])
vec_e = np.array([1, -5])
mantissa_elementwise_product = mat_m * vec_m
exponent_elementwise_product = mat_e + vec_e
elementwise_product = mantissa_elementwise_product * np.array([10.])**exponent_elementwise_product
elementwise_product.sum(1)
# -> approximately [27300.288, 3360.00018]
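
The same trick for arbitrary floats, using np.frexp to split each value into a mantissa and a base-2 exponent (my own sketch; the gist's exact scheme may differ):

import numpy as np

mat = np.array([[1.3e2, 8.e2], [1.6e1, 5.e-1]])
vec = np.array([2.1e2, 3.6e-4])

# frexp gives x = m * 2**e, with m in [0.5, 1)
mat_m, mat_e = np.frexp(mat)
vec_m, vec_e = np.frexp(vec)

# multiply mantissae elementwise; "multiply" the exponents by adding them
elementwise = (mat_m * vec_m) * 2.0 ** (mat_e + vec_e)

assert np.allclose(elementwise.sum(1), mat @ vec)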
2 npm packages printed adverts during postinstall. one author was looking for a job; the other was hiring. I should introduce them
I actually prefer the undersampled image to the converged image. this one is the same seed with a 15-step Karras schedule running the full gamut of sigmas, rho=7. k-diffusion DPM-Solver++ (2M) with (2S) LMS warmup
substantially more converged image
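
For reference, a sketch of the Karras et al. (2022) sigma schedule being described, with illustrative sigma_min/sigma_max values; k-diffusion exposes this as get_sigmas_karras:

import torch

def karras_sigmas(n: int, sigma_min: float, sigma_max: float, rho: float = 7.0) -> torch.Tensor:
    # interpolate linearly in sigma**(1/rho) space, then raise back to the rho power
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    # samplers expect a trailing zero sigma
    return torch.cat([sigmas, torch.zeros(1)])

# a 15-step schedule over an illustrative sigma range
print(karras_sigmas(15, sigma_min=0.0292, sigma_max=14.6146))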