I tried to write the fastest matrix multiplication I know of on a CPU, a GPU, and a TPU. Each snippet wraps the multiplication in a `timeit` context manager. This is a small helper that prints the elapsed wall-clock time of its block; it is not part of NumPy, CuPy, or JAX.

## NumPy on CPU

```python
import numpy as np

def matmul_cpu(N):
    # Generate random matrices directly in host RAM
    A = np.random.rand(N, N).astype(np.float32)
    B = np.random.rand(N, N).astype(np.float32)

    with timeit("Matrix multiplication"):
        C = A @ B

    return C
```

[Colab](https://colab.research.google.com/drive/1JqcNUXCBEWYNXhqr-0nJeMPgkf5LYk6Q?usp=sharing)

## CuPy on GPU

```python
import cupy as cp

def matmul_gpu(N):
    # Generate random matrices directly on the GPU
    A = cp.random.rand(N, N).astype(cp.float32)
    B = cp.random.rand(N, N).astype(cp.float32)

    with timeit("Matrix multiplication"):
        C = A @ B
        # Kernel launches are asynchronous; wait for the GPU to finish
        cp.cuda.Stream.null.synchronize()

    return C
```

[Colab](https://colab.research.google.com/drive/1jpWrCi0G47HRiuRN7JzrS6hS3D5CqXhu?usp=sharing)

## JAX NumPy on TPU

```python
import jax
import jax.numpy as jnp

@jax.jit
def fast_matmul(A, B):
    return A @ B

def matmul_tpu(N):
    key = jax.random.PRNGKey(0)
    key1, key2 = jax.random.split(key)

    # Generate random matrices directly on the TPU
    A = jax.random.uniform(key1, (N, N), dtype=jnp.float32)
    B = jax.random.uniform(key2, (N, N), dtype=jnp.float32)

    # The first call triggers XLA compilation of fast_matmul
    with timeit("Warmup (Compilation)"):
        fast_matmul(A, B).block_until_ready()

    # Later calls reuse the compiled function
    with timeit("Matrix multiplication"):
        C = fast_matmul(A, B).block_until_ready()

    return C
```

[Colab](https://colab.research.google.com/drive/1pn7vsMYoHu_k5KqdWcjqVfp6mp5fnUWN?usp=sharing)

## Comparison

Running each notebook on Google Colab with N = 8192 (two 8192 × 8192 float32 matrices) gave these times:

|        Hardware |     Time |
| --------------: | -------: |
|             CPU |  15.828s |
|   NVIDIA T4 GPU | 0.32121s |
| v5e-1 TPU (cold) | 0.60472s |
| v5e-1 TPU (warm) | 0.00705s |

The cold TPU time includes JIT compilation; the warm time reuses the already compiled function.
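The snippets call `timeit(...)` as a context manager but never define it. A minimal sketch of such a helper, assuming it only needs to print the wall-clock time of its block, could look like this (the `seconds` field of the yielded dict is an illustrative addition, not something the original snippets use):

```python
import time
from contextlib import contextmanager

@contextmanager
def timeit(label):
    """Print the wall-clock time spent inside the `with` block.

    Hypothetical helper: yields a dict whose "seconds" entry is filled
    in when the block exits, so callers can also read the time.
    """
    result = {"label": label, "seconds": None}
    start = time.perf_counter()
    try:
        yield result
    finally:
        result["seconds"] = time.perf_counter() - start
        print(f"{label}: {result['seconds']:.5f}s")
```

Because it measures host time only, GPU and TPU work must be forced to finish inside the block, which is why the CuPy version calls `synchronize()` and the JAX version calls `block_until_ready()` before the block exits.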
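To put the measured times on a common scale, they can be converted to throughput using the standard operation count for an N × N by N × N matrix multiplication: about 2N³ floating-point operations (N³ multiplies and N³ adds). The sketch below applies that count to the timings from the comparison; the cold TPU time is omitted because it is dominated by compilation.

```python
# Convert the measured wall-clock times into throughput (TFLOP/s),
# using the standard ~2*N^3 operation count for an N x N matmul.
N = 8192

# Times copied from the comparison table, in seconds.
# The cold TPU run is omitted because it includes JIT compilation.
times = {
    "CPU": 15.828,
    "NVIDIA T4 GPU": 0.32121,
    "v5e-1 TPU (warm)": 0.00705,
}

def tflops(n, seconds):
    """Return TFLOP/s for multiplying two n x n matrices in `seconds`."""
    flops = 2 * n**3  # n^3 multiply-adds, counted as 2 operations each
    return flops / seconds / 1e12

for name, seconds in times.items():
    print(f"{name:>18}: {tflops(N, seconds):8.3f} TFLOP/s")
```

By this count the warm TPU run reaches roughly 156 TFLOP/s, against about 3.4 TFLOP/s on the T4 and 0.07 TFLOP/s on the CPU. One caveat: on TPU, JAX's default matmul precision for float32 inputs may use bfloat16 arithmetic internally, so the three runs are not necessarily at equal precision; calling `jnp.matmul(A, B, precision=jax.lax.Precision.HIGHEST)` inside `fast_matmul` would request full float32 precision.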