
Rotary Position Embeddings

RoPE (Rotary Position Embedding) is a technique that lets a model learn the positions of tokens. It can be seen as a direct upgrade over sinusoidal embeddings.

The main idea remains the same. Across the sequence length, we want to reason about how far apart the tokens are (relative positioning) and their positions in the entire sequence (absolute positioning).

For this, we build on the same idea as sinusoidal embeddings, but instead of adding sine and cosine values to the token vectors as an extra step, we inject the positions directly by rotating the vectors, which preserves their magnitudes.

Rotation in the Complex Plane

This can be done in the complex plane much more easily than in the real plane. In the real plane, each point $(x, y)$ would be multiplied by the 2D rotation matrix $\begin{pmatrix}\cos \theta & -\sin \theta\\\sin \theta & \cos \theta\end{pmatrix}$, which rotates it by the angle $\theta$.

Now I don’t really like doing it this way, so I prefer the complex plane math using Euler’s formula: $e^{i \theta} = \cos \theta + i \sin \theta$

In the complex plane, $e^{i\theta}$ always lies on the unit circle: if $\theta = 0$, you are at $1$ (i.e. coordinate (1, 0)), and if $\theta = \frac{\pi}{2}$, you are at $i$ (i.e. coordinate (0, 1)). When you multiply a point $z = x + iy$ in the complex plane by $e^{i\theta}$, you don’t change its magnitude, only its direction. To see this, we can expand the terms:

\begin{align*}
(x+iy)(\cos \theta+i\sin \theta) &= x\cos \theta+ix \sin \theta+iy \cos \theta+i^2y\sin \theta \\
&= (x\cos\theta-y\sin\theta)+(ix\sin\theta+iy\cos\theta) && \because (i^2 = -1)\\
&= (x\cos\theta-y\sin\theta)+ i(x\sin\theta+y\cos\theta)
\end{align*}

This is exactly what the rotation matrix produces: the real part is the new $x$ and the imaginary part is the new $y$. So instead of a matrix multiplication, we can do regular multiplication in the complex plane.
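As a quick sanity check, the equivalence can be verified numerically. The sketch below uses only Python's standard library; the helper names `rotate_matrix` and `rotate_complex` and the sample point are made up for illustration.

```python
import cmath
import math


def rotate_matrix(x, y, theta):
    """Rotate (x, y) by theta using the 2D rotation matrix."""
    return (
        x * math.cos(theta) - y * math.sin(theta),
        x * math.sin(theta) + y * math.cos(theta),
    )


def rotate_complex(x, y, theta):
    """Rotate (x, y) by theta by multiplying x + iy with e^{i * theta}."""
    z = complex(x, y) * cmath.exp(1j * theta)
    return (z.real, z.imag)


# Arbitrary sample point and angle (illustrative values only).
x, y, theta = 3.0, 4.0, 0.7
via_matrix = rotate_matrix(x, y, theta)
via_complex = rotate_complex(x, y, theta)

# Both routes agree up to floating-point error.
assert all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(via_matrix, via_complex))

# The rotation leaves the magnitude unchanged.
assert math.isclose(math.hypot(*via_complex), math.hypot(x, y))
```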

Converting Real Vectors into Complex Numbers

Now, since all of our vectors are real-valued, we need to convert them into complex numbers before performing this multiplication. So we pair up adjacent elements to form the real and imaginary parts of a complex number. As such, the number of complex elements becomes dim/2, so we only need to calculate frequencies for half the dimensions.

freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

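Notice how we use float(): it converts the integer indices from torch.arange to floating point before the division and exponentiation, so the frequencies are computed in fp32.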

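We also use an exponential decay: as the dimension index increases, the frequency (how much to rotate per position) gets smaller. Note that theta in the code is the base of this decay, not the rotation angle $\theta$ from the previous section.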

[Figure: RoPE Frequency Progression across Dimensions. A plot showing how the rotation frequency decreases exponentially as the dimension index increases. x-axis: token position (0 to 128); y-axis: rotary value + offset; one curve each for d = 0, 1, 2, 3, 4, 8, 16, 32.]

We also normalize the dimension indices by dividing by dim. The / dim keeps the exponent in the range [0, 1), so the frequencies span from 1 down to just above 1/theta instead of shrinking toward 0 as dim grows.
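To make the decay concrete, here is a small sketch that evaluates the same formula in plain Python. The values `dim = 8` and `theta = 10000.0` are assumptions chosen for illustration (10000 is a common choice of base), and `rope_frequencies` is a hypothetical helper name.

```python
import math


def rope_frequencies(dim, theta):
    """One frequency per complex pair: theta ** -(i / dim) for i = 0, 2, 4, ..."""
    return [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]


freqs = rope_frequencies(dim=8, theta=10000.0)

# dim / 2 frequencies, one for each (real, imaginary) pair.
assert len(freqs) == 4

# The exponent i / dim stays in [0, 1), so the frequencies start at 1
# and decay without ever reaching 1 / theta.
assert freqs[0] == 1.0
assert all(a > b for a, b in zip(freqs, freqs[1:]))
assert freqs[-1] > 1.0 / 10000.0

# Consecutive frequencies differ by a constant ratio: theta ** (-2 / dim).
ratios = [b / a for a, b in zip(freqs, freqs[1:])]
assert all(math.isclose(r, 10000.0 ** (-2 / 8)) for r in ratios)

for i, f in enumerate(freqs):
    print(f"pair {i}: {f:.4f} radians per position")
```

For these values the frequencies come out to roughly 1, 0.1, 0.01 and 0.001, so the first pair rotates about a thousand times faster than the last.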

A Minimal Rotation Example

Putting that into code:

>>> import torch
>>> x_real = torch.tensor([[1.0, 0.0]])
>>> x_complex = torch.view_as_complex(x_real)
>>> theta = torch.tensor([torch.pi / 2])
>>> rotator = torch.polar(torch.ones_like(theta), theta)
>>> rotated_complex = x_complex * rotator
>>> rotated_real = torch.view_as_real(rotated_complex)
>>> print(f"Original Real Pair: {x_real.tolist()}")
... print(f"Complex Form:       {x_complex.item()}")
... print(f"Rotated Complex:    {rotated_complex.item()}")
... print(f"Final Real Pair:    {rotated_real.tolist()}")
Original Real Pair: [[1.0, 0.0]]
Complex Form:       (1+0j)
Rotated Complex:    (-4.371138828673793e-08+1j) # roughly 0+1j
Final Real Pair:    [[-4.371138828673793e-08, 1.0]] # roughly [0, 1]
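
The same multiplication also shows why rotation captures relative position. In the sketch below (plain Python, with made-up values for the query `q`, key `k` and angle `theta`, and hypothetical helpers `rotate` and `score`), each vector is rotated by its position times `theta`, and the resulting score depends only on the offset between the two positions.

```python
import cmath
import math


def rotate(z, position, theta):
    """Rotate the complex number z by position * theta radians."""
    return z * cmath.exp(1j * position * theta)


def score(q, k):
    """Dot product of two 2D vectors stored as complex numbers."""
    return (q * k.conjugate()).real


# Illustrative values: one (real, imaginary) pair each for a query and a key.
q, k, theta = complex(1.0, 2.0), complex(0.5, -1.5), 0.3

# Positions (5, 2) and (13, 10) are both 3 tokens apart.
score_a = score(rotate(q, 5, theta), rotate(k, 2, theta))
score_b = score(rotate(q, 13, theta), rotate(k, 10, theta))
assert math.isclose(score_a, score_b, abs_tol=1e-12)

# A different offset (here 1 token apart) gives a different score.
score_c = score(rotate(q, 5, theta), rotate(k, 4, theta))
assert not math.isclose(score_a, score_c, abs_tol=1e-12)
```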

Computing Frequencies Across Positions

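Then we compute the outer product of the position indices (up to the maximum sequence length) with the frequencies. FYI, an outer product multiplies every pair of elements from two vectors, so entry (i, j) of the result is position i times frequency j, which is the rotation angle for that position and pair.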

>>> a = torch.tensor([1, 2, 3])
>>> torch.outer(a, a)
tensor([[1, 2, 3],
        [2, 4, 6],
        [3, 6, 9]])

And the final code:

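import torch

def precompute_freqs_cis(dim, max_seq_len, theta):  # function name is illustrative
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)  # freqs[p, j] = angle for position p, pair j
    return torch.polar(torch.ones_like(freqs), freqs)  # unit-magnitude e^{i * angle}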
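
As a rough check of what this table looks like, the sketch below rebuilds it with small, made-up sizes (`dim = 8`, `max_seq_len = 16`, `theta = 10000.0`) and inspects a few properties. The function name `build_rotation_table` is hypothetical, and the positions are cast to float here to keep the dtypes explicit.

```python
import torch


def build_rotation_table(dim, max_seq_len, theta):
    """Return a (max_seq_len, dim // 2) complex table of unit rotations."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    angles = torch.outer(t, freqs)  # angles[p, j] = p * freqs[j]
    return torch.polar(torch.ones_like(angles), angles)


table = build_rotation_table(dim=8, max_seq_len=16, theta=10000.0)

# One row per position, one complex rotation per (real, imaginary) pair.
assert table.shape == (16, 4)
assert table.dtype == torch.complex64

# Every entry lies on the unit circle, so multiplying by it keeps magnitudes.
assert torch.allclose(table.abs(), torch.ones(16, 4))

# Position 0 has angle 0, which is the identity rotation 1 + 0j.
assert torch.allclose(table[0].real, torch.ones(4))
assert torch.allclose(table[0].imag, torch.zeros(4))
```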

Shape Gymnastics

Now for the shape gymnastics: applying the precomputed rotations to the query (xq) and key (xk) tensors.

import torch

def apply_rotary_emb(xq, xk, freqs_cis, start_pos):  # function name is illustrative
    head_dim = xq.shape[-1]
    assert head_dim % 2 == 0, "head_dim should be even"

    # xq = (b, sl, nh, hd) -> (b, sl, nh, hd/2, 2) -> reinterpret as complex
    complex_shape = (*xq.shape[:-1], head_dim // 2, 2)

    xq_ = torch.view_as_complex(xq.float().reshape(complex_shape))
    xk_ = torch.view_as_complex(xk.float().reshape(complex_shape))

    # reshape for broadcasting (1, max_seq_len, 1, head_dim // 2)
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)

    seq_len = xq_.shape[1]

    # the table covers every position up to the maximum context length,
    # but we only need the rows for positions start_pos .. start_pos + seq_len - 1.
    freqs = freqs_cis[:, start_pos : start_pos + seq_len]

    # rotate by multiplying in complex space
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)

    # the output was cast to fp32, needs to go back to original dtype
    return xq_out.type_as(xq), xk_out.type_as(xk)
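
To see the pieces working together, here is a self-contained sketch with made-up sizes. It repeats the steps above in a single hypothetical function, `apply_rope`, applied to one tensor at a time, and checks that the rotation keeps the shape, dtype and per-head norms of a random query tensor. The helper `rotation_table`, the base `theta = 10000.0` and all tensor sizes are assumptions for illustration.

```python
import torch


def rotation_table(dim, max_seq_len, theta=10000.0):
    """Complex rotation table of shape (max_seq_len, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(max_seq_len).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rope(x, freqs_cis, start_pos=0):
    """Rotate x of shape (batch, seq_len, n_heads, head_dim)."""
    head_dim = x.shape[-1]
    seq_len = x.shape[1]
    # (b, sl, nh, hd) -> (b, sl, nh, hd/2, 2) -> complex (b, sl, nh, hd/2)
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], head_dim // 2, 2))
    # (max_seq_len, hd/2) -> (1, seq_len, 1, hd/2) for broadcasting
    freqs = freqs_cis[start_pos : start_pos + seq_len].unsqueeze(0).unsqueeze(2)
    # rotate, then go back to real pairs and the input dtype
    return torch.view_as_real(x_ * freqs).flatten(3).type_as(x)


torch.manual_seed(0)
batch, seq_len, n_heads, head_dim = 2, 5, 3, 8
xq = torch.randn(batch, seq_len, n_heads, head_dim)
freqs_cis = rotation_table(head_dim, max_seq_len=32)

xq_out = apply_rope(xq, freqs_cis, start_pos=4)

# Shape and dtype survive the round trip through complex numbers.
assert xq_out.shape == xq.shape
assert xq_out.dtype == xq.dtype

# Rotation preserves the norm of every head vector.
assert torch.allclose(xq_out.norm(dim=-1), xq.norm(dim=-1), atol=1e-5)

# Position 0 is rotated by angle 0, so it comes back unchanged.
assert torch.allclose(apply_rope(xq, freqs_cis, start_pos=0)[:, 0], xq[:, 0], atol=1e-6)
```

Passing a half-precision tensor works the same way: the computation runs in fp32 internally and `type_as` casts the result back.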