Rotary Positional Embeddings
RoPE, or Rotary Positional Embedding, is a technique we implement that injects token positions into the model. It is essentially a direct upgrade over sinusoidal embeddings.
The main idea remains the same: across the sequence, we want the model to reason about both how far apart tokens are (relative positioning) and where each token sits in the whole sequence (absolute positioning).
For this we use the same machinery as sinusoidal embeddings, but instead of adding sine and cosine terms as a separate step, we inject the positions directly by rotating the token vectors while preserving their magnitudes.
This is much easier to do in the complex plane than in the real plane. In the real plane we would need a 2D rotation matrix, and every point would have to be matrix-multiplied with it to get rotated.
Now I don’t really like doing it this way, so I prefer using the complex plane math. This is rooted in Euler’s formula: $e^{i\theta} = \cos\theta + i\sin\theta$. In the complex plane, if $\theta = 0$ you are at $(1, 0)$; if $\theta = \pi/2$ you are at $(0, 1)$. When you multiply a point in the complex plane by $e^{i\theta}$, you don’t change its magnitude, only its direction. To prove this, we can expand the terms:
$$(x + iy)(\cos\theta + i\sin\theta) = (x\cos\theta - y\sin\theta) + i(x\sin\theta + y\cos\theta)$$
This is exactly what the 2D rotation matrix $\begin{pmatrix}\cos\theta & -\sin\theta \\ \sin\theta & \cos\theta\end{pmatrix}$ does to the point $(x, y)$. Instead of a matrix multiplication, we can do a regular multiplication in the complex plane.
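As a quick sanity check (a minimal sketch with an arbitrary angle and point), rotating with the matrix and multiplying by $e^{i\theta}$ land on the same spot:
>>> import math, torch
>>> theta = 0.3
>>> point = torch.tensor([[2.0, 1.0]])                       # (x, y) as a real pair
>>> R = torch.tensor([[math.cos(theta), -math.sin(theta)],
...                   [math.sin(theta),  math.cos(theta)]])  # 2D rotation matrix
>>> point @ R.T                                              # rotate with the matrix
tensor([[1.6152, 1.5464]])
>>> z = torch.view_as_complex(point) * torch.polar(torch.ones(1), torch.tensor([theta]))
>>> torch.view_as_real(z)                                    # same point, via complex multiply
tensor([[1.6152, 1.5464]])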
Now, since all of our vectors are real-valued, we need to convert them to complex before we can actually perform this multiplication. So we pair up adjacent numbers to form the real and imaginary parts of a complex number. As a result, the number of elements becomes n/2, so we only need to calculate frequencies for half the dimensions.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
Notice how we use float() so the exponent is computed in floating point (the arange would otherwise produce an integer tensor) and we avoid numerical issues.
We also use a decay function of $\theta^{-2i/d}$, so as the pair index $i$ increases, the frequency (how much we rotate per position) gets smaller.
![[Pasted image 20260423141546.png]]
We also divide by the dimension to scale the spectrum and keep it from collapsing to 0: the / dim normalizes the exponent to lie between 0 and 1, so the frequencies stay between 1 and 1/theta.
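To make the decay concrete, here is a minimal sketch of the frequencies for a toy head dimension of 8 (the dimension is arbitrary, theta is the usual 10000):
import torch

exponents = torch.arange(0, 8, 2).float() / 8   # [0.00, 0.25, 0.50, 0.75] -> normalized to [0, 1)
freqs = 1.0 / (10000.0 ** exponents)
# roughly [1.0, 0.1, 0.01, 0.001]: the first pair rotates fastest,
# later pairs rotate more and more slowly per position
print(freqs)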
Putting the rotation itself into code:
>>> x_real = torch.tensor([[1.0, 0.0]])
>>> x_complex = torch.view_as_complex(x_real)
>>> theta = torch.tensor([torch.pi / 2])
>>> rotator = torch.polar(torch.ones_like(theta), theta)
>>> rotated_complex = x_complex * rotator
>>> rotated_real = torch.view_as_real(rotated_complex)
>>> print(f"Original Real Pair: {x_real.tolist()}")
... print(f"Complex Form: {x_complex.item()}")
... print(f"Rotated Complex: {rotated_complex.item()}")
... print(f"Final Real Pair: {rotated_real.tolist()}")
Original Real Pair: [[1.0, 0.0]]
Complex Form: (1+0j)
Rotated Complex: (-4.371138828673793e-08+1j) # roughly 0+1j
Final Real Pair: [[-4.371138828673793e-08, 1.0]] # roughly [0, 1]
Then we take an outer product between these frequencies and the position indices up to the maximum sequence length. FYI, the outer product just multiplies every element of one vector with every element of the other:
>>> a = torch.tensor([1, 2, 3])
>>> torch.outer(a, a)
tensor([[1, 2, 3],
        [2, 4, 6],
        [3, 6, 9]])
and final code:
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # per-pair frequencies: theta ** (-2i / dim)
t = torch.arange(max_seq_len)  # absolute positions 0 .. max_seq_len - 1
freqs = torch.outer(t, freqs)  # (max_seq_len, dim // 2) rotation angles
return torch.polar(torch.ones_like(freqs), freqs)  # unit-magnitude complex rotations e^{i * angle}
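Here is a short sketch of what the precomputed tensor looks like (the wrapper name precompute_freqs_cis and the toy sizes are my own choices for illustration; the body is just the snippet above):
import torch

def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # hypothetical wrapper around the snippet above
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(angles), angles)

freqs_cis = precompute_freqs_cis(dim=8, max_seq_len=4)
print(freqs_cis.shape)  # torch.Size([4, 4]) -> one rotation per (position, pair)
print(freqs_cis.abs())  # all ones: every entry is a pure rotation e^{i * angle}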
Now for the shape gymnastics.
head_dim = xq.shape[-1]
assert head_dim % 2 == 0, "head_dim should be even"
# xq = (b, sl, nh, hd) -> (b, sl, nh, hd/2, 2) -> reinterpret as complex
complex_shape = (*xq.shape[:-1], head_dim // 2, 2)
xq_ = torch.view_as_complex(xq.float().reshape(complex_shape))
xk_ = torch.view_as_complex(xk.float().reshape(complex_shape))
# reshape for broadcasting over batch and heads: (1, max_seq_len, 1, head_dim // 2)
freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)
seq_len = xq_.shape[1]
# slice out the rotations for the positions we are actually processing:
# tokens start_pos .. start_pos + seq_len - 1 of the precomputed table
freqs = freqs_cis[:, start_pos : start_pos + seq_len]
# rotate by multiplying in complex space
xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
# the output was (up)cast to fp32; cast it back to the original dtype
return xq_out.type_as(xq), xk_out.type_as(xk)
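To close the loop, here is a self-contained sketch (the function names, the condensed bodies, and the toy shapes are assumptions for illustration, not the original code) that checks the two properties we care about: the rotation preserves vector magnitudes, and the query-key score depends only on how far apart the two positions are:
import torch

def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(max_seq_len, dtype=torch.float32), freqs)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_emb(xq, xk, freqs_cis, start_pos=0):
    # condensed version of the shape gymnastics above
    shape = (*xq.shape[:-1], xq.shape[-1] // 2, 2)
    xq_ = torch.view_as_complex(xq.float().reshape(shape))
    xk_ = torch.view_as_complex(xk.float().reshape(shape))
    freqs = freqs_cis.unsqueeze(0).unsqueeze(2)[:, start_pos : start_pos + xq_.shape[1]]
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

torch.manual_seed(0)
b, sl, nh, hd = 1, 16, 2, 8
xq, xk = torch.randn(b, sl, nh, hd), torch.randn(b, sl, nh, hd)
freqs_cis = precompute_freqs_cis(hd, max_seq_len=32)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)

# 1) pure rotation: per-token magnitudes are unchanged
print(torch.allclose(xq_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-5))  # True

# 2) relative positioning: the q.k score depends only on the gap between positions
q0, k0 = torch.randn(1, 1, 1, hd), torch.randn(1, 1, 1, hd)
def score(m, n):
    qm, _ = apply_rotary_emb(q0, q0, freqs_cis, start_pos=m)
    _, kn = apply_rotary_emb(k0, k0, freqs_cis, start_pos=n)
    return (qm * kn).sum()
print(torch.allclose(score(3, 7), score(10, 14), atol=1e-5))  # True (both gaps are 4)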