
Zero-shot classification as a linear classifier (CLIP)

This notebook demonstrates that zero-shot classification with CLIP can be interpreted as a linear classifier whose weights are generated from text prompts. We follow the notation used on the CLIP lecture page: $\boldsymbol\ell$ for images, $\mathbf{t}$ for text, $f_\ell$ and $f_t$ for the two encoders, and $\mathbf{z}_\ell = f_\ell(\boldsymbol\ell)$, $\mathbf{z}_t = f_t(\mathbf{t})$ for the embeddings. You will:
  1. Derive the formulation
  2. Implement zero-shot classification
  3. Visualize the text classifier weights and the image in the shared embedding space

1. Derivation

CLIP learns two encoders (see the CLIP lecture for the contrastive training objective):

$$\mathbf{z}_\ell = f_\ell(\boldsymbol\ell), \quad \mathbf{z}_t = f_t(\mathbf{t})$$

where $\boldsymbol\ell$ is an image, $\mathbf{t}$ is a text string, $f_\ell$ is the image encoder, and $f_t$ is the text encoder. Both output unit vectors on the surface of the hypersphere $\mathbb{S}^{d_z}$. For zero-shot classification over classes $y \in \lbrace a, b, c, \ldots\rbrace$, we write a prompt $\mathbf{t}^y$ for each class and embed it:

$$\mathbf{z}_t^y = f_t(\mathbf{t}^y)$$

Given a query image $\boldsymbol\ell^q$ with embedding $\mathbf{z}_\ell^q = f_\ell(\boldsymbol\ell^q)$, the prediction is the class whose text embedding has the largest dot product with the image embedding:

$$\hat{y} = \arg\max_y \; (\mathbf{z}_\ell^q)^\top \mathbf{z}_t^y$$

This is exactly a linear classifier over $\mathbf{z}_\ell^q$ whose weight vectors are the per-class text embeddings:

$$\hat{y} = \arg\max_y \; (\mathbf{z}_t^y)^\top \mathbf{z}_\ell^q$$
Each text prompt $\mathbf{t}^y$ instantiates one classifier weight vector $\mathbf{z}_t^y$. There is no training: the classifier is built on the fly from language.
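
To make the linear-classifier view concrete, here is a minimal NumPy sketch. The arrays W and z_q are illustrative stand-ins for the embeddings computed later in the notebook: stacking the per-class text embeddings as rows of a weight matrix and taking an argmax over the dot products is all the classifier does.

import numpy as np

# Hypothetical unit-norm embeddings: one row of W per class prompt z_t^y,
# and z_q standing in for the query image embedding z_l^q.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 512))                      # stacked text embeddings {z_t^y}
W = W / np.linalg.norm(W, axis=1, keepdims=True)
z_q = rng.normal(size=512)                         # image embedding z_l^q
z_q = z_q / np.linalg.norm(z_q)

scores = W @ z_q                                   # one linear score per class
y_hat = int(np.argmax(scores))                     # zero-shot prediction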
!pip install transformers torch torchvision pillow matplotlib scikit-learn
import torch
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
import numpy as np
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (position_embedding): Embedding(50, 768)
    )
    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (visual_projection): Linear(in_features=768, out_features=512, bias=False)
  (text_projection): Linear(in_features=512, out_features=512, bias=False)
)

2. Load a query image $\boldsymbol\ell^q$

url = "https://images.unsplash.com/photo-1518717758536-85ae29035b6d"
image = Image.open(requests.get(url, stream=True).raw)
plt.imshow(image)
plt.axis("off")
Output from cell 4

3. Define prompts $\mathbf{t}^y$ (classifier weights)

labels = ["a dog", "a cat", "a car", "a plane"]
prompts = [f"a photo of {label}" for label in labels]
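
Each prompt wording yields a slightly different weight vector. A refinement described in the CLIP paper, not used in the rest of this notebook, is prompt ensembling: embed several templates per class, average the embeddings, and renormalize to get a single weight vector per class. A minimal sketch using the model and processor loaded above (the templates list is illustrative):

# Prompt ensembling (optional refinement): average several templates per class.
templates = ["a photo of {}", "a picture of {}", "an image of {}"]

ensembled = []
with torch.no_grad():
    for label in labels:
        texts = [t.format(label) for t in templates]
        text_inputs = processor(text=texts, return_tensors="pt", padding=True)
        embeds = model.get_text_features(**text_inputs)
        embeds = embeds / embeds.norm(dim=-1, keepdim=True)   # unit-normalize each template
        mean_embed = embeds.mean(dim=0)
        ensembled.append(mean_embed / mean_embed.norm())      # renormalize the average
ensembled = torch.stack(ensembled)                            # one weight vector z_t^y per class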

4. Compute the embeddings $\mathbf{z}_\ell^q$ and $\lbrace \mathbf{z}_t^y\rbrace$

inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(**inputs)
    image_embeds = outputs.image_embeds
    text_embeds = outputs.text_embeds

# Normalize (important for cosine similarity)
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

5. Zero-shot classification

Apply the classification rule from Section 1:

$$\hat{y} = \arg\max_y \; (\mathbf{z}_\ell^q)^\top \mathbf{z}_t^y$$
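
The argmax itself needs no softmax or temperature. A two-line check using the normalized embeddings from Section 4:

# Prediction by largest dot product: no softmax or temperature needed for argmax.
scores = (image_embeds @ text_embeds.T).squeeze(0)
print("Predicted class:", labels[scores.argmax().item()])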

Why raw softmax looks flat

CLIP cosine similarities live in a narrow range. Both encoders produce $L_2$-normalized embeddings on $\mathbb{S}^{d_z}$, and in practice the joint space is a tight cone, so raw dot products $(\mathbf{z}_\ell^q)^\top \mathbf{z}_t^y$ typically sit between 0.18 and 0.30 even for good matches. A softmax applied directly to such close values produces a nearly uniform distribution, which makes the classifier look much less confident than it actually is. The fix is the learned temperature parameter $\tau$ that CLIP was trained with (the same $\tau$ that appears in the InfoNCE loss on the CLIP lecture page). HuggingFace's CLIPModel stores its logarithm as the parameter logit_scale, so model.logit_scale.exp() gives $1/\tau \approx 100$. Inference must scale the similarities by $1/\tau$ before the softmax:

$$p(y \mid \boldsymbol\ell^q) = \mathrm{softmax}_y\!\left(\frac{1}{\tau} \, (\mathbf{z}_\ell^q)^\top \mathbf{z}_t^y\right)$$

HuggingFace's CLIP model exposes this scaled value as outputs.logits_per_image, so you get the sharp classification output for free. The cell below compares raw cosine similarity, naive softmax (wrong), temperature-scaled softmax (correct), and the HuggingFace convenience output.
# Raw cosine similarity between image and each text prompt
similarity = (image_embeds @ text_embeds.T).squeeze(0)

# 1. Naive softmax — this is the common pitfall
naive_probs = similarity.softmax(dim=0)

# 2. Correctly scaled softmax using CLIP's learned temperature
#    `logit_scale` is a learned parameter (≈4.6 in log space, ≈100 in linear),
#    and CLIP was trained with contrastive loss at that scale.
logit_scale = model.logit_scale.exp().item()
scaled_probs = (similarity * logit_scale).softmax(dim=0)

# 3. HuggingFace convenience — logits_per_image already has the scale applied
hf_probs = outputs.logits_per_image.softmax(dim=-1).squeeze(0)

print(f"Learned logit_scale (temperature): {logit_scale:.2f}")
print()
print(f"{'label':<10} {'cos sim':>10} {'naive':>10} {'scaled':>10} {'HF':>10}")
print("-" * 55)
for i, label in enumerate(labels):
    print(
        f"{label:<10} "
        f"{similarity[i].item():>10.4f} "
        f"{naive_probs[i].item():>10.4f} "
        f"{scaled_probs[i].item():>10.4f} "
        f"{hf_probs[i].item():>10.4f}"
    )
Learned logit_scale (temperature): 100.00

label         cos sim      naive     scaled         HF
-------------------------------------------------------
a dog          0.2758     0.2664     0.9990     0.9990
a cat          0.2048     0.2481     0.0008     0.0008
a car          0.1866     0.2437     0.0001     0.0001
a plane        0.1788     0.2418     0.0001     0.0001

6. Visualizing the classifier weights and the image in the same space

Plotting the first 50 components of each text embedding side by side is not very informative: the embeddings live on a $d_z$-dimensional unit hypersphere and the raw coordinates have no intrinsic meaning. A better view is to project the weights into a low-dimensional subspace and see how they are laid out relative to each other and to the image embedding $\mathbf{z}_\ell^q$. Below we fit a 2-component PCA to the text embeddings $\lbrace \mathbf{z}_t^y\rbrace$ and project both the text prompts and the image into the same plane. Each labeled point is one classifier weight; the red star is the image embedding $\mathbf{z}_\ell^q$. The label closest to the star in the full embedding space is the zero-shot prediction $\hat{y}$ (the 2D projection preserves this only approximately).
from sklearn.decomposition import PCA

# Project the text classifier weights to 2D with PCA so we can see them
# arranged in the shared CLIP embedding space. Each point is one prompt,
# and semantically related prompts should cluster together.
weights = text_embeds.numpy()          # {z_t^y}
image_vec = image_embeds.numpy()       # z_l^q

pca = PCA(n_components=2)
pca.fit(weights)

proj_text = pca.transform(weights)
proj_image = pca.transform(image_vec)

fig, ax = plt.subplots(figsize=(7, 6))

# Text classifier weights z_t^y
ax.scatter(proj_text[:, 0], proj_text[:, 1], s=180, c="steelblue", edgecolors="black", zorder=3, label=r"text prompts $z_t^y$ (classifier weights)")
for i, label in enumerate(labels):
    ax.annotate(
        label,
        (proj_text[i, 0], proj_text[i, 1]),
        xytext=(8, 8),
        textcoords="offset points",
        fontsize=11,
    )

# Image embedding z_l^q in the same PCA basis
ax.scatter(proj_image[:, 0], proj_image[:, 1], s=260, c="crimson", marker="*", edgecolors="black", zorder=4, label=r"image embedding $z_\ell^q$")

ax.set_xlabel(f"PC1  ({pca.explained_variance_ratio_[0]:.1%} var)")
ax.set_ylabel(f"PC2  ({pca.explained_variance_ratio_[1]:.1%} var)")
ax.set_title("Classifier weights and query image in 2D (PCA of text embeddings)")
ax.axhline(0, color="gray", linewidth=0.5)
ax.axvline(0, color="gray", linewidth=0.5)
ax.legend(loc="best")
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nPCA variance explained: PC1={pca.explained_variance_ratio_[0]:.1%}, PC2={pca.explained_variance_ratio_[1]:.1%}")
print()
# Temperature-scaled logits and softmax probabilities — consistent with Section 5
logit_scale = model.logit_scale.exp().item()
raw_sims = (image_vec @ weights.T).flatten()
logits = logit_scale * raw_sims
probs = np.exp(logits - logits.max())
probs = probs / probs.sum()

print(f"Temperature-scaled classification (1/tau = {logit_scale:.2f}):")
print(f"{'label':<10} {'cos sim':>10} {'logit':>10} {'prob':>10}")
print("-" * 45)
for label, s, lg, p in zip(labels, raw_sims, logits, probs):
    print(f"{label:<10} {s:>10.4f} {lg:>10.2f} {p:>10.4f}")
print()
print(f"Predicted class: {labels[int(np.argmax(probs))]}")
Output from cell 8

PCA variance explained: PC1=50.8%, PC2=31.4%

Temperature-scaled classification (1/tau = 100.00):
label         cos sim      logit       prob
---------------------------------------------
a dog          0.2758      27.58     0.9990
a cat          0.2048      20.48     0.0008
a car          0.1866      18.66     0.0001
a plane        0.1788      17.88     0.0001

Predicted class: a dog

7. Conclusion

We demonstrated that:
  • Each class prompt $\mathbf{t}^y$ produces a vector $\mathbf{z}_t^y = f_t(\mathbf{t}^y)$
  • These vectors act as classifier weights
  • Zero-shot classification is linear classification in the shared CLIP embedding space
$$\hat{y} = \arg\max_y \; (\mathbf{z}_t^y)^\top \mathbf{z}_\ell^q$$
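
As a recap, the whole pipeline fits in one small function. A sketch reusing the model, processor, and image loaded above (the function name zero_shot_classify is illustrative):

def zero_shot_classify(image, labels, template="a photo of {}"):
    """Zero-shot classification: text prompts become linear classifier weights."""
    prompts = [template.format(label) for label in labels]
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        out = model(**inputs)
    # logits_per_image already includes the learned 1/tau scale
    probs = out.logits_per_image.softmax(dim=-1).squeeze(0)
    return labels[int(probs.argmax())], probs

pred, probs = zero_shot_classify(image, ["a dog", "a cat", "a car", "a plane"])
print(pred)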