Recently, I've been exploring two-tower architectures after coming across some interesting implementations in Medium posts and Kaggle notebooks.

When experimenting with new architectures, I like to design a synthetic dataset. The design process forces me to think more carefully about the data distributions the model is supposed to capture. I suppose this is somewhat analogous to unit testing in software engineering, although designing synthetic inputs is typically more involved than writing simple deterministic test inputs.

In this post, I'll detail my exploration of a two-tower model, covering this synthetic data creation, the model build, evaluation, and subsequent improvements using contrastive loss.

Objective

My objective was straightforward:

  • Build a simple two-tower model to rank query-document pairs.
  • Generate synthetic query-document pairs with controlled overlap.
  • Ensure the model effectively learns query-document relationships.
  • Experiment with why and how InfoNCE (contrastive loss) can improve retrieval results.

Code Snippets

  • Code for the MarginRankingLoss experiment here.
  • Code for the InfoNCE experiment here.

Synthetic Data Setup

To start, I created synthetic query-document pairs with explicit token overlap control. The explicit token overlap helps us ensure there is something meaningful for the model to learn:

import torch

def build_synthetic(
  n_queries=500, vocab=100, q_len=4, doc_len=16, overlap=0.6, seed=42
):
  rng = torch.Generator().manual_seed(seed)
  queries = torch.randint(0, vocab, (n_queries, q_len), generator=rng)
  docs_pos = torch.randint(0, vocab, (n_queries, doc_len), generator=rng)

  # Copy tokens from queries to documents with the specified overlap
  for i in range(n_queries):
    num_copy = int(overlap * q_len)
    copy_idx = torch.randperm(q_len, generator=rng)[:num_copy]
    docs_pos[i, :num_copy] = queries[i, copy_idx]

  # Create negative samples by rolling the positive documents
  docs_neg = docs_pos.roll(shifts=1, dims=0)
  return queries, docs_pos, docs_neg

The overlap parameter also controls the difficulty of the ranking task.

I generated negative samples by shifting the positive documents array by one position (docs_neg = docs_pos.roll(shifts=1, dims=0)).

Specifically, the positive document originally paired with query i (d_pos_i) becomes the negative sample paired with query i+1 (d_neg_{i+1}). So the positive document at index 4 reappears as the negative document at index 5, where it is ranked against the query at index 5.

Since the negative samples are real documents from the dataset, just shifted by one position, the ranking task remains challenging: the tokens that weren't copied from the query are drawn from the same distribution as the tokens in the positive documents.
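
To make the roll behavior concrete, here is a tiny sketch (my own example, not part of the original script) showing how the shift pairs each query with its neighbor's positive document:

import torch

docs_pos = torch.tensor([[10, 11], [20, 21], [30, 31]])
docs_neg = docs_pos.roll(shifts=1, dims=0)
# docs_neg is [[30, 31], [10, 11], [20, 21]]:
# the negative for query 1 is the positive document of query 0, and so on,
# with the last positive wrapping around to become query 0's negative.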

Model: Two-Tower Setup

For the ranking task, I used a two-tower neural architecture, consisting of embedding layers, mean pooling, linear projections, and cosine similarity scoring:

import torch
import torch.nn as nn
from torch.nn.functional import normalize

class TwoTower(nn.Module):
  def __init__(self, vocab, emb_dim=72, proj_dim=72):
    super().__init__()
    self.embedding = nn.Embedding(vocab, emb_dim)
    self.proj = nn.Linear(emb_dim, proj_dim, bias=False)

  def encode(self, toks):
    # (batch, seq_len, emb_dim) -> mean pool over the sequence -> (batch, emb_dim)
    x = self.embedding(toks).mean(dim=1)
    # project, then L2-normalize so dot products become cosine similarities
    return normalize(self.proj(x), dim=-1)

  def forward(self, q, d):
    qv, dv = self.encode(q), self.encode(d)
    return (qv * dv).sum(dim=-1)

Encoding

Each tower independently encodes inputs into fixed-size vectors:

  1. Mean Pooling: After retrieving token embeddings (self.embedding(toks)), which results in a tensor of shape (batch_size, sequence_length, embedding_dimension), the mean(dim=1) operation averages the embeddings across the sequence_length dimension.

The mean(dim=1) pooling collapses the variable-length sequence into a single fixed-size vector while still capturing the overall meaning of the input. We're essentially saying "my representation is the centroid of the token vectors." It's fast, but it ignores token order.

  2. Projection and L2 Normalization: The mean-pooled vector is then passed through a linear projection layer (self.proj(x)), and the output of the projection is normalized with normalize(..., dim=-1), which scales each vector to unit L2 norm (length 1).

The normalization step ensures vectors have unit length, turning the dot product into a cosine similarity, ideal for semantic similarity measurement.
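
As a quick sanity check on the shapes (a minimal sketch, assuming the default dimensions from the class above):

import torch

model = TwoTower(vocab=100, emb_dim=72, proj_dim=72)
toks = torch.randint(0, 100, (8, 16))  # (batch_size=8, sequence_length=16)
emb = model.embedding(toks)            # (8, 16, 72) token embeddings
pooled = emb.mean(dim=1)               # (8, 72) mean pooled over the sequence
vecs = model.encode(toks)              # (8, 72) projected and L2-normalized
print(vecs.norm(dim=-1))               # all ones: each row has unit length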

Forward Pass

The forward method computes the cosine similarity between a query and a document, which is the score used to rank documents by relevance:

return (qv * dv).sum(dim=-1)
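
Because encode() returns unit-length vectors, that element-wise product followed by a sum is exactly cosine similarity. A quick check of that claim (my own snippet, not from the model code):

import torch
import torch.nn.functional as F

a, b = torch.randn(4, 72), torch.randn(4, 72)
dot = (F.normalize(a, dim=-1) * F.normalize(b, dim=-1)).sum(dim=-1)
print(torch.allclose(dot, F.cosine_similarity(a, b, dim=-1), atol=1e-6))  # True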

Training Loop

The initial model uses MarginRankingLoss to distinguish positive from negative samples:

loss_fn = nn.MarginRankingLoss(margin=0.25)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

for epoch in range(epochs):
  model.train()
  total_loss = 0.0

  for batch in loader:
    pos_scores = model(batch['q'], batch['d_pos'])
    neg_scores = model(batch['q'], batch['d_neg'])

    # a target of 1 tells the loss that pos_scores should rank above neg_scores
    target = torch.ones_like(pos_scores).to(device)
    loss = loss_fn(pos_scores, neg_scores, target)

    opt.zero_grad()
    loss.backward()
    opt.step()

    total_loss += loss.item()

  avg_loss = total_loss / len(loader)
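
For reference, with a target of 1, nn.MarginRankingLoss computes max(0, neg_score - pos_score + margin) per pair, so the loss only reaches zero once the positive score beats the negative score by at least the margin. A small hand-check of that formula (my own sketch, not part of the training code):

import torch
import torch.nn as nn

loss_fn = nn.MarginRankingLoss(margin=0.25)
pos = torch.tensor([0.90, 0.60])
neg = torch.tensor([0.50, 0.55])
target = torch.ones_like(pos)

# Per pair: max(0, neg - pos + 0.25) -> [0.0, 0.2]; default reduction is the mean
print(loss_fn(pos, neg, target))  # tensor(0.1000)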

At the end of training, these two plots can help us understand if the model learned anything:

[Figure: left, distributions of positive and negative cosine similarity scores; right, histogram of the score differences (positive - negative) with the 0.25 margin marked.]

  • In the chart on the left, the pos scores are the cosine similarities between queries and their positive docs, and the neg scores are the cosine similarities between queries and their negative docs.

  • The histogram on the right shows the distribution of the score difference (positive_score - negative_score) between the relevant (positive) and irrelevant (negative) documents for the same query. Most score differences are positive and clustered to the right of the target margin line (0.25), indicating the model has learned to rank positive documents well above negative ones, as the margin loss objective intends.

Evaluating with Recall@k

I also computed Recall@k across the entire document corpus. This measures whether the correct document appears within the top-k results for each query; in my evaluation I used k=10:

@torch.no_grad()
def recall_at_k_corpus(model, q, docs, k=10, batch_size_eval=256):
  # Encode the full document corpus in batches
  doc_vecs = []
  for i in range(0, len(docs), batch_size_eval):
    batch_docs = docs[i : i + batch_size_eval]
    doc_vecs.append(model.encode(batch_docs))
  doc_vecs = torch.cat(doc_vecs, 0)
  doc_mat = doc_vecs.t()  # (proj_dim, n_docs), ready for query @ doc_mat

  hits = 0
  for i in range(0, len(q), batch_size_eval):
    batch_q = q[i : i + batch_size_eval]
    qv = model.encode(batch_q)
    sim = qv @ doc_mat  # cosine similarity of each query against every document
    topk_indices = sim.topk(k, dim=1).indices

    # The correct document for query i sits at corpus index i,
    # so a hit means that index appears in the query's top-k
    target_indices = torch.arange(i, i + qv.size(0), device=q.device).unsqueeze(1)
    hits += (topk_indices == target_indices).any(dim=1).sum().item()

  recall = hits / len(q)
  return recall
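
Assuming queries and docs_pos come from build_synthetic above and model is the trained tower, the evaluation call looks like this (a usage sketch; see the linked code for the exact script):

model.eval()
recall = recall_at_k_corpus(model, queries, docs_pos, k=10)
print(f"Recall@10: {recall:.2%}")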

For this model using MarginRankingLoss, I got a Recall@10 of ~51% with these key parameters for the synthetic data:

n_queries = 500
vocab = 50
q_len = 16
doc_len = 48
overlap = 0.8
seed = 1337
batch_size = 16
margin = 0.25
lr = 3e-4
epochs = 10
emb_dim = 48
proj_dim = 72
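
For context, here's roughly how those values map onto the pieces defined above (a sketch, not the exact script from the linked code):

queries, docs_pos, docs_neg = build_synthetic(
  n_queries=500, vocab=50, q_len=16, doc_len=48, overlap=0.8, seed=1337
)
model = TwoTower(vocab=50, emb_dim=48, proj_dim=72)
loss_fn = nn.MarginRankingLoss(margin=0.25)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)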

Improving Performance with InfoNCE

The effectiveness of Margin Ranking Loss is limited, especially with simple bag-of-words encoders, since the model only ever learns to separate the positive document from one random negative at a time.

A more powerful approach for training retrieval models is InfoNCE (a contrastive loss based on Noise Contrastive Estimation), often implemented using nn.CrossEntropyLoss. The core idea of InfoNCE is to make the embedding of a query q_i more similar to its corresponding positive document d_i^+ than to all other documents in the batch. In effect we get negative samples for free: for each query-document pair, every other document in the batch acts as a negative.

Objective: Make the embedding of each query q_i closest to its own positive d_i^+ and far from all the other positives d_j^+ (for j ≠ i). The key change is in the training loop:

loss_fn = nn.CrossEntropyLoss()
...

for batch in loader:
  q_vec = model.encode(batch['q'])
  d_vec = model.encode(batch['d_pos'])  # note: no explicit negative docs required

  # similarity of every query against every document in the batch
  logits = q_vec @ d_vec.T
  # the positive document for query i sits at column i (the diagonal)
  labels = torch.arange(logits.size(0), device=logits.device)
  loss = loss_fn(logits, labels)

CrossEntropyLoss:

  • Applies softmax across each row.
  • Penalizes when the probability mass isn’t on the diagonal element.
  • The gradient pushes q_i closer to d_i^+ and farther from the other documents d_j.
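
To convince myself this in-batch setup really is InfoNCE (with the temperature fixed at 1), here is a small check of my own that computes the loss by hand and compares it to nn.CrossEntropyLoss:

import torch
import torch.nn as nn
import torch.nn.functional as F

q_vec = F.normalize(torch.randn(4, 72), dim=-1)
d_vec = F.normalize(torch.randn(4, 72), dim=-1)
logits = q_vec @ d_vec.T
labels = torch.arange(4)

# InfoNCE by hand: negative log of the softmax probability assigned to the
# diagonal (positive) entry of each row, averaged over the batch
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), labels].mean()
print(manual, nn.CrossEntropyLoss()(logits, labels))  # the two values match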

Results

For the model trained with InfoNCE, I got a Recall@10 of ~58% with these key parameters for the synthetic data. Note that the margin parameter is no longer needed:

n_queries = 500
vocab = 100
q_len = 16
doc_len = 60
overlap = 0.5
seed = 1337
batch_size = 16
lr = 3e-4
epochs = 10
emb_dim = 36
proj_dim = 72