Here I show the fastest approaches I know of for matrix multiplication on the CPU, GPU, and TPU.
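All three snippets time the multiply with a small `timeit` context manager, whose definition is not shown in the originals. A minimal sketch of such a helper (hypothetical; the Colab notebooks may define it differently):

```python
import time
from contextlib import contextmanager

@contextmanager
def timeit(label):
    # Print the wall-clock time spent inside the `with` block
    start = time.perf_counter()
    yield
    print(f"{label}: {time.perf_counter() - start:.5f}s")
```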
## NumPy on CPU
```python
import numpy as np
def matmul_cpu(N):
    # Generate random matrices directly in RAM
    A = np.random.rand(N, N).astype(np.float32)
    B = np.random.rand(N, N).astype(np.float32)
    with timeit("Matrix multiplication"):
        C = A @ B
    return C
```
[Colab](https://colab.research.google.com/drive/1JqcNUXCBEWYNXhqr-0nJeMPgkf5LYk6Q?usp=sharing)
## CuPy on GPU
```python
import cupy as cp
def matmul_gpu(N):
    # Generate random matrices directly on the GPU
    A = cp.random.rand(N, N).astype(cp.float32)
    B = cp.random.rand(N, N).astype(cp.float32)
    with timeit("Matrix multiplication"):
        C = A @ B
        # GPU kernel launches are asynchronous; wait for completion
        # so the timing reflects the actual compute
        cp.cuda.Stream.null.synchronize()
    return C
```
[Colab](https://colab.research.google.com/drive/1jpWrCi0G47HRiuRN7JzrS6hS3D5CqXhu?usp=sharing)
## JAX NumPy on TPU
```python
import jax
import jax.numpy as jnp
@jax.jit
def fast_matmul(A, B):
    return A @ B

def matmul_tpu(N):
    key = jax.random.PRNGKey(0)
    key1, key2 = jax.random.split(key)
    # Generate random matrices directly on the TPU
    A = jax.random.uniform(key1, (N, N), dtype=jnp.float32)
    B = jax.random.uniform(key2, (N, N), dtype=jnp.float32)
    with timeit("Warmup (Compilation)"):
        fast_matmul(A, B).block_until_ready()
    with timeit("Matrix multiplication"):
        C = fast_matmul(A, B).block_until_ready()
    return C
```
[Colab](https://colab.research.google.com/drive/1pn7vsMYoHu_k5KqdWcjqVfp6mp5fnUWN?usp=sharing)
## Comparison
When run on Google Colab, these are the times for multiplying two 8192 × 8192 float32 matrices.
| Device | Time |
| --------------: | -------: |
| CPU | 15.828s |
| NVIDIA T4 GPU | 0.32121s |
| v5e-1 TPU (cold) | 0.60472s |
| v5e-1 TPU (warm) | 0.00705s |
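As a rough sanity check, these times can be converted into effective throughput: a dense N × N matmul performs about 2N³ floating-point operations. (Peak-FLOPS figures for the actual Colab hardware are not stated here, so this mainly compares the runs to each other.)

```python
def effective_tflops(N, seconds):
    # A dense N x N matmul costs ~2 * N^3 floating-point operations
    return 2 * N**3 / seconds / 1e12

N = 8192
for device, t in [("CPU", 15.828), ("T4 GPU", 0.32121), ("v5e-1 TPU (warm)", 0.00705)]:
    print(f"{device}: {effective_tflops(N, t):.2f} TFLOP/s")
```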