Exploring Two-Tower Ranking Models
Recently, I've been exploring two-tower architectures after coming across some interesting implementations in Medium posts and Kaggle notebooks.
When experimenting with new architectures, I like to design a synthetic dataset. The design process forces me to think more carefully about the data distributions the model is supposed to capture. I suppose this is somewhat analogous to unit testing in software engineering, although designing synthetic inputs is typically more involved than writing simple deterministic test inputs.
In this post, I'll detail my exploration of a two-tower model, covering this synthetic data creation, the model build, evaluation, and subsequent improvements using contrastive loss.
Objective
My objective was straightforward:
- Build a simple two-tower model to rank query-document pairs.
- Generate synthetic query-document pairs with controlled overlap.
- Ensure the model effectively learns query-document relationships.
- Experiment with why and how InfoNCE (contrastive loss) can improve retrieval results.
Code Snippets
Synthetic Data Setup
To start, I created synthetic query-document pairs with explicit token overlap control. The explicit token overlap helps us ensure there is something meaningful for the model to learn:
import torch

def build_synthetic(
    n_queries=500, vocab=100, q_len=4, doc_len=16, overlap=0.6, seed=42
):
    rng = torch.Generator().manual_seed(seed)
    queries = torch.randint(0, vocab, (n_queries, q_len), generator=rng)
    docs_pos = torch.randint(0, vocab, (n_queries, doc_len), generator=rng)
    # Copy tokens from queries to documents with the specified overlap
    for i in range(n_queries):
        num_copy = int(overlap * q_len)
        copy_idx = torch.randperm(q_len, generator=rng)[:num_copy]
        docs_pos[i, :num_copy] = queries[i, copy_idx]
    # Create negative samples by rolling the positive documents
    docs_neg = docs_pos.roll(shifts=1, dims=0)
    return queries, docs_pos, docs_neg
The overlap parameter also controls the difficulty of the ranking task.

I generated negative samples by shifting the positive documents array by one position (docs_neg = docs_pos.roll(shifts=1, dims=0)). Specifically, the positive document originally paired with query i (d_pos_i) becomes the negative sample paired with query i+1 (d_neg_{i+1}). For example, the negative document at index 5 is the same document that serves as the positive for query 4.

Since the negative samples are actual documents from the dataset, just shifted by one position, their tokens come from the same distribution as the positives, which keeps the ranking task challenging.
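To make the pairing concrete, here's a tiny sketch (with made-up token values) of what roll does to the positive documents:

```python
import torch

docs_pos = torch.tensor([[10, 11], [20, 21], [30, 31]])  # positives for queries 0, 1, 2
docs_neg = docs_pos.roll(shifts=1, dims=0)
print(docs_neg)
# tensor([[30, 31],    <- negative for query 0 is the positive from query 2 (wrap-around)
#         [10, 11],    <- negative for query 1 is the positive from query 0
#         [20, 21]])   <- negative for query 2 is the positive from query 1
```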
Model: Two-Tower Setup
For the ranking task, I used a two-tower neural architecture, consisting of embedding layers, mean pooling, linear projections, and cosine similarity scoring:
import torch
import torch.nn as nn
from torch.nn.functional import normalize

class TwoTower(nn.Module):
    def __init__(self, vocab, emb_dim=72, proj_dim=72):
        super().__init__()
        self.embedding = nn.Embedding(vocab, emb_dim)
        self.proj = nn.Linear(emb_dim, proj_dim, bias=False)

    def encode(self, toks):
        x = self.embedding(toks).mean(dim=1)
        return normalize(self.proj(x), dim=-1)

    def forward(self, q, d):
        qv, dv = self.encode(q), self.encode(d)
        return (qv * dv).sum(dim=-1)
Encoding
Each tower independently encodes inputs into fixed-size vectors:
- Mean Pooling: After retrieving the token embeddings (self.embedding(toks)), which results in a tensor of shape (batch_size, sequence_length, embedding_dimension), the mean(dim=1) operation averages the embeddings across the sequence_length dimension.

Mean pooling collapses the sequence into a single fixed-size vector while still capturing the overall semantic content of the input. We're essentially saying "my representation is the centroid of the token vectors." It's fast, but it ignores token order.

- L2 Normalization: The mean-pooled vector is then passed through a linear projection layer (self.proj(x)), and the output of this projection is normalized using normalize(..., dim=-1). This scales each vector to a unit L2 norm (length of 1).

The normalization step ensures the vectors have unit length, which turns the dot product into a cosine similarity, ideal for semantic similarity measurement.
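To make the shapes concrete, here's a small sketch (sizes assumed) of what the pooling, projection, and normalization steps do to a batch:

```python
import torch
import torch.nn as nn
from torch.nn.functional import normalize

emb = nn.Embedding(100, 72)              # vocab=100, emb_dim=72
proj = nn.Linear(72, 72, bias=False)     # emb_dim -> proj_dim
toks = torch.randint(0, 100, (8, 16))    # (batch_size=8, sequence_length=16)

x = emb(toks)                            # (8, 16, 72): one embedding per token
pooled = x.mean(dim=1)                   # (8, 72): centroid of the token vectors
unit = normalize(proj(pooled), dim=-1)   # (8, 72): each row now has L2 norm 1
print(x.shape, pooled.shape, unit.norm(dim=-1))
```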
Forward Pass
The forward method computes cosine similarity, effectively ranking documents by relevance:
return (qv * dv).sum(dim=-1)
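Because encode() returns unit-length vectors, this elementwise product and sum is exactly cosine similarity. A quick sanity check (a sketch, assuming the TwoTower class above is in scope):

```python
import torch
import torch.nn.functional as F

model = TwoTower(vocab=100)
q = torch.randint(0, 100, (8, 4))    # batch of 8 queries, 4 tokens each
d = torch.randint(0, 100, (8, 16))   # batch of 8 documents, 16 tokens each

scores = model(q, d)                                                   # dot product of unit vectors
reference = F.cosine_similarity(model.encode(q), model.encode(d), dim=-1)
print(torch.allclose(scores, reference, atol=1e-6))                    # True
```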
Training Loop
The initial model uses MarginRankingLoss to distinguish positive from negative samples:
loss_fn = nn.MarginRankingLoss(margin=0.25)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for batch in loader:
        pos_scores = model(batch['q'], batch['d_pos'])
        neg_scores = model(batch['q'], batch['d_neg'])
        target = torch.ones_like(pos_scores).to(device)
        loss = loss_fn(pos_scores, neg_scores, target)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(loader)
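With target = 1, MarginRankingLoss reduces to a hinge on the score gap: max(0, margin - (pos_score - neg_score)). A quick check with made-up scores:

```python
import torch
import torch.nn as nn

pos_scores = torch.tensor([0.9, 0.6, 0.3])
neg_scores = torch.tensor([0.2, 0.5, 0.4])
target = torch.ones_like(pos_scores)

loss_fn = nn.MarginRankingLoss(margin=0.25)
loss = loss_fn(pos_scores, neg_scores, target)

# Manual hinge: clamp(margin - (pos - neg), min=0), averaged over the batch
manual = torch.clamp(0.25 - (pos_scores - neg_scores), min=0).mean()
print(loss.item(), manual.item())  # both ~0.1667
```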
At the end of training, these two plots can help us understand if the model learned anything:
- In the chart on the left, the pos scores represent the cosine similarity between queries and their positive documents, and the neg scores represent the cosine similarity between queries and their negative documents.
- The histogram on the right shows the distribution of the score difference (positive_score - negative_score) for each query. Most score differences are positive and clustered to the right of the target margin line (0.25), indicating that the model has learned to rank positive documents well above negative ones, as intended by the margin loss objective.
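These diagnostics are cheap to produce; a minimal sketch of the score-difference histogram (assuming matplotlib, the full tensors from build_synthetic, and everything on the same device):

```python
import matplotlib.pyplot as plt
import torch

model.eval()
with torch.no_grad():
    pos_scores = model(queries, docs_pos)   # cosine similarity to the true document
    neg_scores = model(queries, docs_neg)   # cosine similarity to the shifted negative

diffs = (pos_scores - neg_scores).cpu().numpy()
plt.hist(diffs, bins=30)
plt.axvline(0.25, linestyle="--", label="margin = 0.25")
plt.xlabel("pos_score - neg_score")
plt.legend()
plt.show()
```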
Evaluating with Recall@k
I also computed Recall@k across the entire document corpus. This measures whether the correct document appears within the top-k results for each query. In my evaluation I used k = 10:
@torch.no_grad()
def recall_at_k_corpus(model, q, docs, k=10, batch_size_eval=256):
    # Encode the full document corpus in batches
    doc_vecs = []
    for i in range(0, len(docs), batch_size_eval):
        batch_docs = docs[i : i + batch_size_eval]
        doc_vecs.append(model.encode(batch_docs))
    doc_vecs = torch.cat(doc_vecs, 0)
    doc_mat = doc_vecs.t()

    # For each query, check whether its own document lands in the top-k
    hits = 0
    for i in range(0, len(q), batch_size_eval):
        batch_q = q[i : i + batch_size_eval]
        qv = model.encode(batch_q)
        sim = qv @ doc_mat
        topk_indices = sim.topk(k, dim=1).indices
        target_indices = torch.arange(i, i + qv.size(0), device=q.device).unsqueeze(1)
        hits += (topk_indices == target_indices).any(dim=1).sum().item()

    recall = hits / len(q)
    return recall
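Called on the synthetic data, it looks something like this (a sketch, assuming queries and docs_pos come from build_synthetic and the model is already trained):

```python
recall = recall_at_k_corpus(model, queries, docs_pos, k=10)
print(f"Recall@10: {recall:.2%}")
```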
For this model using MarginRankingLoss, I got a Recall@10 of ~51% with these key parameters for the synthetic data:
n_queries = 500
vocab = 50
q_len = 16
doc_len = 48
overlap = 0.8
seed = 1337
batch_size = 16
margin = 0.25
lr = 3e-4
epochs = 10
emb_dim = 48
proj_dim = 72
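For reference, the training loop above assumes a loader that yields dict batches with 'q', 'd_pos', and 'd_neg' keys. A minimal sketch of wiring that up with these parameters (PairDataset here is just an illustrative helper, not from the original code):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    def __init__(self, q, d_pos, d_neg):
        self.q, self.d_pos, self.d_neg = q, d_pos, d_neg
    def __len__(self):
        return len(self.q)
    def __getitem__(self, idx):
        return {"q": self.q[idx], "d_pos": self.d_pos[idx], "d_neg": self.d_neg[idx]}

queries, docs_pos, docs_neg = build_synthetic(
    n_queries=500, vocab=50, q_len=16, doc_len=48, overlap=0.8, seed=1337
)
loader = DataLoader(PairDataset(queries, docs_pos, docs_neg), batch_size=16, shuffle=True)
model = TwoTower(vocab=50, emb_dim=48, proj_dim=72)
```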
Improving Performance with InfoNCE
The effectiveness of Margin Ranking Loss is limited, especially with simple bag-of-words encoders, since the model only learns to distinguish the positive document from one random negative at a time.

A more powerful approach for training retrieval models is InfoNCE (a noise-contrastive estimation loss), often implemented using nn.CrossEntropyLoss. The core idea of InfoNCE is to make the embedding of a query more similar to its corresponding positive document than to all other documents in the batch. In effect, we get negative samples for free: for each query, every other document in the batch acts as a negative.

Objective: make the embedding of each query q_i closest to its own positive document d_i and far from all other documents d_j (for j ≠ i) in the batch. The key change is in the forward pass:
loss_fn = nn.CrossEntropyLoss()
...
for batch in loader:
    q_vec = model.encode(batch['q'])
    d_vec = model.encode(batch['d_pos'])
    # note: no negative docs required, the rest of the batch acts as negatives
    logits = q_vec @ d_vec.T
    labels = torch.arange(logits.size(0), device=logits.device)
    loss = loss_fn(logits, labels)
CrossEntropyLoss:
- Applies softmax across each row of the similarity matrix.
- Penalizes the model when the probability mass isn't on the diagonal element.
- The gradient pushes q_i closer to d_i and farther from the other d_j in the batch.
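In other words, for row i the per-query loss is -log(exp(s_ii) / Σ_j exp(s_ij)). A quick check, with a made-up in-batch similarity matrix, that nn.CrossEntropyLoss with arange labels computes exactly this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[0.9, 0.1, 0.2],
                       [0.0, 0.8, 0.3],
                       [0.1, 0.2, 0.7]])   # similarity of query i (row) to document j (column)
labels = torch.arange(3)                   # the "correct class" for row i is column i

ce = nn.CrossEntropyLoss()(logits, labels)
manual = -F.log_softmax(logits, dim=1).diag().mean()
print(ce.item(), manual.item())            # identical values
```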
Results
For the model trained with InfoNCE, I got a Recall@10 of ~58% with these key parameters for the synthetic data (margin = 0.25 is no longer needed):
n_queries = 500
vocab = 100
q_len = 16
doc_len = 60
overlap = 0.5
seed = 1337
batch_size = 16
lr = 3e-4
epochs = 10
emb_dim = 36
proj_dim = 72