Scientyfic World

How to Fix Hidden Layer Shape Mismatches in PyTorch

You know the error. It always shows up right when you think the forward pass is “basically wired up”.

RuntimeError: mat1 and mat2 shapes cannot be multiplied (BxF and HxO)

Or it fails later with something like a torch.cat dimension mismatch, even though your inputs look fine.

This post is about the real failure mode: a PyTorch shape mismatch where the batch and feature dimensions silently end up in the wrong places, especially around Linear, Conv2d, flatten, and concatenation.

The broken version (and why it feels like it should work)

Here’s a minimal example that fails because the Linear layer’s input size is hard-coded to a wrong feature dimension.

import torch
import torch.nn as nn

class BadMLP(nn.Module):
    def __init__(self, in_features_hypothesis=10, hidden=16):
        super().__init__()
        # You guessed this in_features size.
        self.fc = nn.Linear(in_features_hypothesis, hidden)

    def forward(self, x):
        # x is (B, F) but F is not what you guessed.
        return self.fc(x)

model = BadMLP(in_features_hypothesis=10)
x = torch.randn(8, 12)  # B=8, F=12

out = model(x)  # <-- boom

You’ll get the classic mat1 and mat2 shapes cannot be multiplied. The message tells you what PyTorch tried: it saw an input with feature dimension 12, but your Linear expects 10.
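If you want to see the raw message without crashing the script, you can catch it. The snippet below is a minimal sketch that rebuilds the same layer and input as plain variables (`fc`, `x`, and `message` are illustrative names, not from the example above). The exact wording can vary between PyTorch versions, but the two shapes it reports are the input and the transposed weight.

```python
import torch
import torch.nn as nn

fc = nn.Linear(10, 16)   # weight is stored as (out_features, in_features) = (16, 10)
x = torch.randn(8, 12)   # B=8, F=12, so F does not match in_features=10

try:
    fc(x)
except RuntimeError as err:
    # mat1 is the input (8x12); mat2 is the transposed weight (10x16).
    message = str(err)
    print(message)
```

Reading the second shape as (in_features x out_features) is what lets you map the message back to the layer you constructed.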

This is the easy case. The harder one is when the mismatch originates a few layers earlier, so the line that raises the error is not the line where your shape assumption went wrong.

First: separate “batch size” from “feature size”

The most common confusion I see is treating batch as if it were a feature dimension, or vice versa.

In PyTorch:

  • Linear expects input shaped (*, in_features): any number of leading dimensions (batch, sequence, ...) followed by a last dimension equal to in_features.
  • matmul (which backs Linear) multiplies the last dimension of your input against the in_features dimension of the weight matrix.
  • Conv2d treats channels as a specific dimension (input is typically (N, C, H, W)), and it will happily produce outputs whose spatial sizes depend on kernel/stride/padding.

So when you see “cannot be multiplied (BxF and HxO)”, the important bit is: PyTorch believes your last dimension is F, and your Linear believes it should be in_features = H (based on how you constructed the layer). The second shape is the transposed weight, so O is the layer's out_features.
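To make the "last dimension" rule concrete, here is a small sketch (the sizes 12, 16, and 5 are arbitrary choices for illustration). It shows that Linear only looks at the final dimension and leaves every leading dimension untouched.

```python
import torch
import torch.nn as nn

fc = nn.Linear(in_features=12, out_features=16)
print(fc.weight.shape)          # stored as (out_features, in_features)

x2d = torch.randn(8, 12)        # (B, F)
x3d = torch.randn(8, 5, 12)     # (B, T, F), e.g. a batch of sequences

y2d = fc(x2d)                   # only the last dim changes: 12 -> 16
y3d = fc(x3d)
print(y2d.shape, y3d.shape)
```

Both calls succeed because the last dimension is 12 in each case; the batch and sequence dimensions pass through unchanged.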

Now the version that breaks after concatenation

Concatenation is where shapes go to die. You can pass tensors that “look compatible” but are incompatible along the concatenation axis, or you can end up with a feature dimension that no longer matches the next Linear.

class BadConcatMLP(nn.Module):
    def __init__(self):
        super().__init__()
        # You think concatenation gives you 10 features total.
        self.fc = nn.Linear(10, 8)

    def forward(self, x1, x2):
        # x1: (B, 4), x2: (B, 6)  => cat along features => (B, 10)
        # That part is fine.
        z = torch.cat([x1, x2], dim=1)
        return self.fc(z)

model = BadConcatMLP()
x1 = torch.randn(5, 4)
x2 = torch.randn(5, 7)  # <-- changed feature count (now cat => 11)

out = model(x1, x2)  # <-- matmul shape mismatch

Change x2 from 6 to 7 features and your concatenated tensor becomes (B, 11) instead of (B, 10). Linear(10, 8) can’t multiply with an input whose last dimension is 11.
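One way to avoid the hard-coded 10 is to derive in_features from the branch widths. This is a sketch, not the only fix: `ConcatMLP` and its `feature_sizes` argument are hypothetical names introduced here, and it assumes you know each branch's width when you build the model.

```python
import torch
import torch.nn as nn

class ConcatMLP(nn.Module):
    def __init__(self, feature_sizes=(4, 7), out_features=8):
        super().__init__()
        self.feature_sizes = tuple(feature_sizes)
        # in_features is computed from the branches instead of guessed.
        self.fc = nn.Linear(sum(self.feature_sizes), out_features)

    def forward(self, *inputs):
        z = torch.cat(inputs, dim=1)   # (B, F1 + F2 + ...)
        return self.fc(z)

model = ConcatMLP(feature_sizes=(4, 7))
out = model(torch.randn(5, 4), torch.randn(5, 7))
print(out.shape)
```

If a branch width changes, you update `feature_sizes` in one place and the Linear follows.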

This is why developers end up adding random reshapes or resizing images “until it works”. You don’t need luck. You need shape tracing.

PyTorch debugging: trace shapes like a circuit

When dimension errors are hidden inside your model, print shapes right where the assumptions change. Usually that means just before a layer that expects a specific dimension, or just after a layer that produces a new one:

  • Linear: right before it
  • Conv2d: right after it (to capture new H/W)
  • flatten: right after you flatten (to confirm the feature count)
  • cat/stack: immediately after to confirm the concatenated feature dimension

Here’s a small “trace everything that matters” rewrite of the concat example.

class DebugConcatMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 8)

    def forward(self, x1, x2):
        print("x1:", x1.shape, "x2:", x2.shape)
        z = torch.cat([x1, x2], dim=1)
        print("cat z:", z.shape)

        # Assert the part you care about.
        assert z.shape[1] == self.fc.in_features, (
            f"Expected in_features={self.fc.in_features}, got {z.shape[1]}"
        )

        y = self.fc(z)
        print("out:", y.shape)
        return y

The assert turns a cryptic runtime error into an immediate, local failure with the actual numbers. That’s the whole game.
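If you would rather not edit forward at all, forward hooks can do the printing. The helper below is a sketch: `trace_shapes` is a hypothetical name, it relies on `nn.Module.register_forward_hook`, and it assumes each traced layer returns a single tensor (other return types are reported by type name only).

```python
import torch
import torch.nn as nn

def trace_shapes(model):
    """Register hooks that print input/output shapes of every leaf module."""
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip containers; trace only leaf layers

        def hook(mod, inputs, output, name=name):
            in_shapes = [tuple(t.shape) for t in inputs if isinstance(t, torch.Tensor)]
            out_desc = tuple(output.shape) if isinstance(output, torch.Tensor) else type(output).__name__
            print(f"{name or type(mod).__name__}: {in_shapes} -> {out_desc}")

        handles.append(module.register_forward_hook(hook))
    return handles

model = nn.Sequential(nn.Linear(12, 16), nn.ReLU(), nn.Linear(16, 4))
handles = trace_shapes(model)
out = model(torch.randn(8, 12))

for h in handles:
    h.remove()  # detach the hooks once you are done debugging
```

Hooks run after a layer's forward completes, so when a layer raises, the last line printed belongs to the last layer that succeeded. The failing layer is the next one.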

Why Linear / Conv2d / flatten shift dimensions

Linear

nn.Linear(in_features, out_features) uses the last dimension of the input as in_features.

If your tensor is (B, F), it works if and only if F == in_features.

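If your tensor is accidentally (F, B), you’ll usually get the same matmul mismatch, but the bug is now “dimension order”, not “feature count”. If B happens to equal in_features, the call runs without any error and the output is silently wrong.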
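The swapped case is worth seeing once, because it does not always fail. In this sketch the sizes are chosen on purpose (B equal to F equal to 12, an assumption made for the demo) so that the transposed input still satisfies Linear.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(12, 16)

x = torch.randn(12, 12)          # (B=12, F=12): batch size happens to equal F
x_swapped = x.transpose(0, 1)    # (F, B), still 12 in the last dimension

y = fc(x)
y_swapped = fc(x_swapped)        # no error is raised

same_shape = y.shape == y_swapped.shape
same_values = torch.allclose(y, y_swapped)
print(same_shape, same_values)
```

The shapes agree while the values do not, which is why an assert on the last dimension alone cannot catch every ordering bug. A test with a batch size that differs from the feature count will.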

Conv2d

nn.Conv2d operates on (N, C, H, W). It changes:

  • C (to out_channels)
  • H and W based on kernel/stride/padding

So if you later build Linear expecting a fixed flattened size, any change in input image size (or even just a different crop) can break it.
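The spatial size after a convolution follows the formula given in the Conv2d documentation. The helper below is a sketch of that formula for one dimension (`conv2d_out_size` is a hypothetical name; the stride of 2 is chosen only to make the size change visible), checked against a real layer.

```python
import torch
import torch.nn as nn

def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    # floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)

for size in (32, 28):
    y = conv(torch.zeros(1, 3, size, size))
    expected = conv2d_out_size(size, kernel_size=3, stride=2, padding=1)
    print(size, "->", tuple(y.shape), "formula:", expected)
```

A 32-pixel input comes out at 16 and a 28-pixel input at 14, so a Linear sized for one will not accept the other.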

flatten

Common pattern: z = conv(x); z = z.flatten(1)

flatten(1) means “keep batch dim, flatten everything else into one feature vector”. The resulting feature dimension is:

C_out * H_out * W_out

If H/W change, your Linear input size changes. That’s the dimension mismatch you’ll see next.
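A quick sketch of what that means in numbers, using zero tensors that stand in for conv outputs with 8 channels (the sizes are illustrative):

```python
import torch

z32 = torch.zeros(4, 8, 32, 32)   # stand-in for a conv output on 32x32 images
z28 = torch.zeros(4, 8, 28, 28)   # same network, 28x28 images

flat32 = z32.flatten(1)           # (4, 8*32*32)
flat28 = z28.flatten(1)           # (4, 8*28*28)
flat_all = z32.flatten()          # no start_dim: the batch dimension is merged too

print(flat32.shape, flat28.shape, flat_all.shape)
```

The first two differ only in the feature count (8192 versus 6272), which is exactly what the next Linear will complain about. The third shows why dropping the `1` is a bug: the batch dimension disappears.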

cat

torch.cat(tensors, dim=...) requires that all non-concat dimensions match exactly.

For feature concatenation with MLPs, it’s usually dim=1 for (B, F) tensors. After concatenation, your last feature dimension becomes:

F_total = F1 + F2 + ...

If you built Linear(in_features=F_hypothesis), then that hypothesis must match reality.
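The three outcomes of cat are easy to confuse, so here is a small sketch with illustrative sizes: the intended feature concat, a wrong-axis concat that still runs, and one that fails.

```python
import torch

a = torch.randn(5, 4)
b = torch.randn(5, 7)
c = torch.randn(5, 4)

features = torch.cat([a, b], dim=1)      # intended: (5, 4 + 7)
wrong_axis = torch.cat([a, c], dim=0)    # runs, but stacks batches: (10, 4)

try:
    torch.cat([a, b], dim=0)             # non-concat dim differs (4 vs 7)
except RuntimeError as err:
    cat_error = str(err)
    print(cat_error)

print(features.shape, wrong_axis.shape)
```

The middle case is the dangerous one: nothing fails at the cat, and the problem only surfaces in a later layer or in the loss.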

Edge cases that make shape bugs feel “random”

  • Variable-length inputs: If different batches contain different sequence lengths or different spatial sizes (common with padding/cropping), your flatten size can change between iterations. You can’t hard-code in_features.
  • Changing image sizes: Resize/crop transforms that don’t exactly match what you trained with will change H_out/W_out and break the flatten → Linear boundary.
  • Silent swaps (batch vs feature): If you accidentally use x = x.transpose(0, 1) (or you forgot a permute after dataset preprocessing), you’ll feed Linear with (F, B). The error won’t say “you swapped dims”. It’ll say “matmul can’t multiply those shapes”.
  • cat axis confusion: Concatenating along the wrong dimension might still run if sizes match, but it changes your feature layout. The mismatch may show up only in the next layer.
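For the variable-size cases above, one common option is to pool to a fixed spatial size before flattening. This is a different technique from the ones discussed in this post, shown only as a sketch: `PooledConvNet` is a hypothetical name, and the 4x4 pooled size is an arbitrary assumption.

```python
import torch
import torch.nn as nn

class PooledConvNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Always outputs (N, 8, 4, 4), whatever H and W come in.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(8 * 4 * 4, num_classes)

    def forward(self, x):
        z = self.pool(self.conv(x)).flatten(1)
        return self.fc(z)

model = PooledConvNet()
shapes = [tuple(model(torch.zeros(2, 3, s, s)).shape) for s in (32, 28, 45)]
print(shapes)
```

The trade-off is that pooling averages away spatial detail, so whether this is acceptable depends on the task.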

A practical fix: stop guessing Linear input sizes

You have two options when building models that include conv + flatten or concatenation:

  1. Compute the feature size from a known input (often best for fixed-size inputs).
  2. Validate and fail fast with asserts so you catch mismatch near the boundary (often best for variable-size inputs).

Here’s the “compute from a sample” approach for conv → flatten → linear.

class GoodConvMLP(nn.Module):
    def __init__(self, in_channels=3, num_classes=10, img_size=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Determine flattened size using a dummy forward.
        with torch.no_grad():
            x = torch.zeros(1, in_channels, img_size, img_size)
            y = self.conv(x)
            flat = y.flatten(1).shape[1]

        self.fc = nn.Linear(flat, num_classes)

    def forward(self, x):
        z = self.conv(x)
        z = z.flatten(1)
        # Sanity check for unexpected sizes.
        assert z.shape[1] == self.fc.in_features, (
            f"Linear expects {self.fc.in_features} features, got {z.shape[1]}"
        )
        return self.fc(z)

This is not magic. It just replaces “guessing” with “measuring” using the same layer stack. If your runtime input sizes differ from img_size, the assert will tell you immediately.
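PyTorch also ships `nn.LazyLinear`, which defers the same measurement to the first forward call. The sketch below is an alternative to the dummy-forward pattern, not a replacement for it; the layer sizes mirror the example above, and the PyTorch docs describe lazy modules as still under development.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),          # default start_dim=1 keeps the batch dimension
    nn.LazyLinear(10),     # in_features is inferred on the first forward
)

# One dry run materializes the weights.
out = model(torch.zeros(1, 3, 32, 32))
print(model[3].in_features, out.shape)
```

Once materialized, in_features is fixed, so a later input with a different image size still fails at this layer. It is also safest to do the dry run before creating the optimizer, so the parameters it receives are already initialized.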

Bonus: a quick shape-debug checklist

When you hit a PyTorch shape mismatch, do this before you change random code:

  1. Locate the boundary: The line mentioned in the error is usually the first layer that can’t reconcile shapes (often Linear or cat).
  2. Print shapes right before it: one line per tensor: print(name, t.shape).
  3. Assert the assumption: For Linear, assert x.shape[-1] == self.fc.in_features.
  4. Check axis choices: For cat, confirm dim matches your intended feature axis (e.g., dim=1 for (B, F)).
  5. Watch for flatten: confirm flatten(1) keeps batch intact and produces the feature count your Linear expects.
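Steps 2 and 3 of the checklist can be folded into one small helper. This is a sketch: `check_linear_input` is a hypothetical name, and raising ValueError instead of using assert is a choice made here so the check survives `python -O`.

```python
import torch
import torch.nn as nn

def check_linear_input(x, layer, name="x"):
    """Fail fast, with real numbers, if x cannot feed the given Linear."""
    if x.shape[-1] != layer.in_features:
        raise ValueError(
            f"{name}: last dim is {x.shape[-1]} (shape {tuple(x.shape)}), "
            f"but Linear expects in_features={layer.in_features}"
        )
    return x

fc = nn.Linear(10, 8)
z = torch.cat([torch.randn(5, 4), torch.randn(5, 7)], dim=1)   # (5, 11)

try:
    fc(check_linear_input(z, fc, name="z"))
except ValueError as err:
    check_message = str(err)
    print(check_message)
```

Because the helper returns its input, it can wrap the argument of any Linear call without changing the surrounding code.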

Here’s the reframing that usually fixes the “hidden” part: most shape bugs aren’t random. They’re consistent mistakes about which dimension means what. Once you trace shapes at the boundaries, the mismatch stops being mysterious and becomes a simple dimension-order or feature-count correction.

If you want, paste your failing layer + the input tensor shapes you think you’re using (just the torch.Size values). I can point out the exact mismatch boundary.

Snehasish Konger
Developed @scientyficworld.org | Technical writer @Nected | Content Developer