Rotary Position Embeddings
Rotary Position Embedding (RoPE) is a technique for encoding the positions of tokens so that a model can make use of them. It can be seen as a direct upgrade over sinusoidal embeddings.
The main idea remains the same. Across the sequence length, we want to reason about how far apart the tokens are (relative positioning) and their positions in the entire sequence (absolute positioning).
For this, RoPE uses the same kind of sine and cosine frequencies as sinusoidal embeddings, but instead of adding sine and cosine values to the token vectors as a separate step, it injects the positions directly by rotating the vectors, which preserves their magnitudes.
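To make the relative-position claim concrete, here is a minimal sketch showing that rotating a 2D query and key by angles proportional to their positions keeps their norms and makes their dot product depend only on the offset between the positions. The helper name `rotate`, the sample vectors, and the single frequency value 0.5 are illustrative assumptions, not part of any library.

```python
import torch

def rotate(vec, position, freq=0.5):
    """Rotate a real 2D vector by the angle position * freq.

    `rotate` and `freq=0.5` are hypothetical, chosen only for illustration.
    """
    angle = torch.tensor(position * freq)
    z = torch.view_as_complex(vec) * torch.polar(torch.tensor(1.0), angle)
    return torch.view_as_real(z)

q = torch.tensor([0.3, -1.2])
k = torch.tensor([0.8, 0.5])

# Same offset (3 positions apart) at two very different absolute positions.
score_a = torch.dot(rotate(q, 5), rotate(k, 2))
score_b = torch.dot(rotate(q, 105), rotate(k, 102))

print(score_a, score_b)                                          # the two scores agree
print(torch.allclose(score_a, score_b, atol=1e-4))               # relative offset is all that matters
print(torch.allclose(rotate(q, 5).norm(), q.norm(), atol=1e-5))  # magnitude is preserved
```

The absolute position still enters through each individual rotation; it simply cancels out of the dot product, leaving only the difference of the two angles.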
Rotation in the Complex Plane
This can be done much more easily in the complex plane than in the real plane. In the real plane, rotating a point means multiplying it by a 2D rotation matrix.
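Now I don't really like doing it this way, so I prefer the complex-plane math based on Euler's formula:
$$e^{i\theta} = \cos\theta + i\sin\theta$$
In the complex plane, if $\theta = 0$, you are at (1, 0); if $\theta = \pi/2$, you are at $i$ (i.e. coordinate (0, 1)). When you multiply a point in the complex plane by $e^{i\theta}$, you don't change its magnitude, only its direction, because $|e^{i\theta}| = \sqrt{\cos^2\theta + \sin^2\theta} = 1$. To prove this, we can expand the terms for a point $x + iy$:
$$(x + iy)(\cos\theta + i\sin\theta) = (x\cos\theta - y\sin\theta) + i\,(x\sin\theta + y\cos\theta)$$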
This is the same as the rotation matrix. Instead of a matrix multiplication, we can do regular multiplication in the complex plane.
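As a quick check of that equivalence, the sketch below rotates the same point once with a 2x2 rotation matrix and once with a complex multiplication. The angle and the point are arbitrary values picked for illustration.

```python
import math
import torch

theta = math.pi / 3                 # arbitrary rotation angle
point = torch.tensor([2.0, 1.0])    # the point (x, y)

# Real plane: multiply by the 2x2 rotation matrix.
rotation = torch.tensor([
    [math.cos(theta), -math.sin(theta)],
    [math.sin(theta),  math.cos(theta)],
])
via_matrix = rotation @ point

# Complex plane: treat (x, y) as x + iy and multiply by e^{i * theta}.
z = torch.view_as_complex(point)
unit_rotator = torch.polar(torch.tensor(1.0), torch.tensor(theta))
via_complex = torch.view_as_real(z * unit_rotator)

print(via_matrix)
print(via_complex)
print(torch.allclose(via_matrix, via_complex, atol=1e-6))  # both routes give the same point
```

The complex route needs one multiplication per pair instead of building and applying a matrix, which is why it is convenient when many pairs have to be rotated at once.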
Converting Real Vectors into Complex Numbers
Now, since all of our vectors are real-valued, we need to convert them into complex numbers before performing this multiplication. So we pair up adjacent elements to form the real and imaginary parts of a complex number. A vector of dim real values therefore becomes dim/2 complex numbers, so we only need to calculate frequencies for half the dimensions.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
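Notice the float(): torch.arange returns an integer tensor here, so we cast it to floating point before dividing and exponentiating.
The frequencies decay exponentially: as the dimension index increases, the frequency (how much to rotate per position) gets smaller. Note that theta in this code is the base of the exponent, not the rotation angle itself.
We also normalize the dimension indices by dividing by dim. The / dim keeps the exponent in the range [0, 1), so the frequencies stay between 1 and 1/theta no matter how large dim is.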
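To see the frequency spectrum for a concrete case, the sketch below evaluates the formula for a small dimension. The values dim = 8 and theta = 10000.0 are assumptions picked for readability (10000 is a commonly used base); they are not fixed by the text above.

```python
import torch

dim = 8          # small dimension, assumed for readability
theta = 10000.0  # base of the exponent, assumed here

exponents = torch.arange(0, dim, 2).float() / dim  # 0.00, 0.25, 0.50, 0.75
freqs = 1.0 / (theta ** exponents)                 # one frequency per complex pair

print(exponents)
print(freqs)        # approximately 1, 0.1, 0.01, 0.001
print(freqs.shape)  # dim // 2 entries
```

The first pair rotates by a full radian per position, while the last pair rotates a thousand times more slowly, so different pairs are sensitive to different distance scales.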
A Minimal Rotation Example
Putting that into code:
>>> import torch
>>> x_real = torch.tensor([[1.0, 0.0]])
>>> x_complex = torch.view_as_complex(x_real)
>>> theta = torch.tensor([torch.pi / 2])
>>> rotator = torch.polar(torch.ones_like(theta), theta)
>>> rotated_complex = x_complex * rotator
>>> rotated_real = torch.view_as_real(rotated_complex)
>>> print(f"Original Real Pair: {x_real.tolist()}")
>>> print(f"Complex Form: {x_complex.item()}")
>>> print(f"Rotated Complex: {rotated_complex.item()}")
>>> print(f"Final Real Pair: {rotated_real.tolist()}")
Original Real Pair: [[1.0, 0.0]]
Complex Form: (1+0j)
Rotated Complex: (-4.371138828673793e-08+1j) # roughly 0+1j
Final Real Pair: [[-4.371138828673793e-08, 1.0]] # roughly [0, 1]
Computing Frequencies Across Positions
Then we compute the outer product of the position indices (up to the maximum sequence length) with the frequencies. An outer product multiplies every pair of elements from two vectors, so entry (m, k) of the result is the rotation angle for position m and frequency k.
>>> a = torch.tensor([1, 2, 3])
>>> torch.outer(a, a)
tensor([[1, 2, 3],
        [2, 4, 6],
        [3, 6, 9]])
And the final code:
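import torch

# Wrapped in a function so the return is valid; the name and signature are assumed.
def precompute_freqs_cis(dim, max_seq_len, theta):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)
    t = torch.arange(max_seq_len, dtype=torch.float32)  # position indices
    freqs = torch.outer(t, freqs)  # angles, shape (max_seq_len, dim // 2)
    return torch.polar(torch.ones_like(freqs), freqs)  # unit complex numbers e^{i * angle}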
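As a sanity check on the result, the sketch below wraps the same steps in a function and inspects what comes out. The function name, its signature, the default base of 10000.0, and the sizes used are assumptions made for illustration.

```python
import torch

def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
    # Hypothetical wrapper: name, signature, and default base are assumptions.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(t, freqs)                        # (max_seq_len, dim // 2)
    return torch.polar(torch.ones_like(angles), angles)   # complex rotators

freqs_cis = precompute_freqs_cis(dim=8, max_seq_len=16)

print(freqs_cis.shape)  # one row per position, one column per complex pair
print(freqs_cis.dtype)  # a complex dtype
print(torch.allclose(freqs_cis.abs(), torch.ones(16, 4)))  # every rotator has magnitude 1
print(freqs_cis[0])     # position 0 has angle 0, so no rotation
```

Because every entry has magnitude 1, multiplying by this table can only change directions, never lengths. The table depends only on dim, the maximum length, and the base, so it can be computed once and reused.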
Shape Gymnastics
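What remains is bookkeeping: reshape the query and key tensors into complex pairs, broadcast the precomputed rotations over the batch and head dimensions, multiply, and convert back to real tensors.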
# Wrapped in a function so the return is valid; the name and signature are assumed.
def apply_rotary_emb(xq, xk, freqs_cis, start_pos):
    head_dim = xq.shape[-1]
    assert head_dim % 2 == 0, "head_dim should be even"
    # xq = (b, sl, nh, hd) -> (b, sl, nh, hd/2, 2) -> reinterpret as complex (b, sl, nh, hd/2)
    complex_shape = (*xq.shape[:-1], head_dim // 2, 2)
    xq_ = torch.view_as_complex(xq.float().reshape(complex_shape))
    xk_ = torch.view_as_complex(xk.float().reshape(complex_shape))
    # reshape for broadcasting: (max_seq_len, hd/2) -> (1, max_seq_len, 1, hd/2)
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)
    seq_len = xq_.shape[1]
    # freqs_cis covers the maximum context length; we only need the rows for
    # the positions being processed: start_pos up to start_pos + seq_len - 1.
    freqs = freqs_cis[:, start_pos : start_pos + seq_len]
    # rotate by multiplying in complex space, then flatten the (hd/2, 2) pairs back to hd
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    # the computation ran in fp32, so cast back to the original dtype
    return xq_out.type_as(xq), xk_out.type_as(xk)
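To tie the pieces together, here is a minimal end-to-end sketch that applies the rotation to random query and key tensors and checks the properties discussed earlier. The function names, signatures, the default base of 10000.0, and the tensor sizes are all assumptions made for illustration.

```python
import torch

def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
    # Hypothetical helper: name, signature, and default base are assumptions.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(max_seq_len, dtype=torch.float32), freqs)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_emb(xq, xk, freqs_cis, start_pos=0):
    # Hypothetical helper following the steps described above.
    head_dim = xq.shape[-1]
    complex_shape = (*xq.shape[:-1], head_dim // 2, 2)
    xq_ = torch.view_as_complex(xq.float().reshape(complex_shape))
    xk_ = torch.view_as_complex(xk.float().reshape(complex_shape))
    seq_len = xq_.shape[1]
    freqs = freqs_cis.unsqueeze(0).unsqueeze(2)[:, start_pos : start_pos + seq_len]
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

batch, seq_len, n_heads, head_dim = 2, 5, 3, 8
xq = torch.randn(batch, seq_len, n_heads, head_dim)
xk = torch.randn(batch, seq_len, n_heads, head_dim)
freqs_cis = precompute_freqs_cis(head_dim, max_seq_len=32)

# Pretend 4 tokens were already processed, so these 5 tokens sit at positions 4..8.
xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cis, start_pos=4)
print(xq_out.shape)  # same shape as the input
print(torch.allclose(xq_out.norm(dim=-1), xq.norm(dim=-1), atol=1e-5))  # norms preserved

# With start_pos=0 the first token sits at position 0 and is left unrotated.
xq_zero, _ = apply_rotary_emb(xq, xk, freqs_cis, start_pos=0)
print(torch.allclose(xq_zero[:, 0], xq[:, 0]))

# The output dtype follows the input dtype even though the math runs in fp32.
xq_bf16, _ = apply_rotary_emb(xq.to(torch.bfloat16), xk.to(torch.bfloat16), freqs_cis)
print(xq_bf16.dtype)
```

The start_pos argument is what makes this usable for incremental decoding: when tokens are processed a few at a time, each chunk is rotated by the angles for its true positions rather than starting again from 0.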