Marginal Musings

Super-Resolution Research, Part 6: Implementing SRCNN from Scratch

Building the original super-resolution CNN from Dong et al. (2014) end-to-end -- dataset creation with sliding windows, the 3-layer architecture, per-layer learning rates, and why my PSNR numbers didn't match the paper's.

Author: Shlomo Stept
Note: Originally written 2022-08

Implementing SRCNN from Scratch in PyTorch: The Paper vs. The Reality

SRCNN — Super-Resolution Convolutional Neural Network — was the first paper to apply deep learning to single image super-resolution. Published by Dong et al. in 2014, it demonstrated that a simple three-layer CNN could learn the mapping from low-resolution to high-resolution images, outperforming decades of hand-engineered methods. The architecture is almost comically simple by modern standards: three convolutions, two ReLUs, and about 69,000 parameters.

I implemented it from scratch during my CS 497 independent study in summer 2022 as the foundation for understanding the super-resolution pipeline end-to-end. What I expected was a weekend project. What I got was three months of debugging preprocessing pipelines, discovering that resizing libraries disagree by up to 21 dB, finding a bug in clean-fid, and building a SIFT-based dataset alignment system. The SRCNN model itself was the easy part. Everything around it was the hard part.

From MATLAB to PyTorch: Why Reimplementation Matters

The original SRCNN paper was implemented in MATLAB with Caffe — the standard stack for computer vision research in 2014. MATLAB was the default for academic image processing at the time because it shipped with a well-tested Image Processing Toolbox, and Caffe was the deep learning framework that most vision labs used before TensorFlow and PyTorch existed. If you wanted to reproduce a 2014 CV paper exactly as the authors did, you would need MATLAB, Caffe, and the specific versions of both that the authors happened to be running.

I chose to reimplement in PyTorch because I wanted to understand the super-resolution pipeline, not just run it. There is a difference between downloading someone else’s MATLAB code, calling imresize and caffe.train, and getting numbers that match the paper — and actually building the dataset pipeline, the model, the training loop, and the evaluation from scratch in a framework where none of those decisions are made for you. The first approach tells you that SRCNN works. The second approach tells you why it works, and more importantly, it tells you all the places where “just implement the paper” turns out to require dozens of decisions that the paper never specifies.

(This distinction became the animating question of my entire summer research. The SRCNN paper says “bicubic downsampling.” It does not say which library, which filter parameters, which anti-aliasing settings, or which coordinate mapping convention. These choices matter. They matter enough to shift PSNR by over 20 dB. And I only discovered this because I was building from scratch in PyTorch rather than running the authors’ MATLAB code.)

The MATLAB-to-PyTorch translation also forced me to confront a fact about reproducibility in image processing that I think most people in the field know but rarely talk about: MATLAB’s imresize and Python’s image processing libraries do not produce the same output for the same input. Not approximately different — measurably, significantly different. The interpolation kernels are different, the anti-aliasing defaults are different, the boundary handling is different, and the coordinate mapping conventions are different. When the original SRCNN paper reports PSNR numbers, those numbers are entangled with MATLAB’s specific implementation of bicubic interpolation in ways that make exact reproduction in Python essentially impossible. You can get close. You cannot match.

The Architecture: Three Layers, That’s It

The SRCNN paper describes three conceptual stages, each implemented as a single convolutional layer:

  1. Patch extraction and representation: Conv2d(3, 64, 9x9) + ReLU — extracts overlapping patches from the input and represents each as a 64-dimensional feature vector
  2. Non-linear mapping: Conv2d(64, 32, 5x5) + ReLU — maps the 64-dim patch representation to a 32-dim representation of the high-resolution patch
  3. Reconstruction: Conv2d(32, 3, 5x5) — aggregates the high-resolution patch predictions into the final output image

No padding is used, so the spatial dimensions shrink at each layer: a 33x33 input produces a 17x17 output (33 - 8 - 4 - 4 = 17).
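This arithmetic is easy to sanity-check with a tiny helper (a sketch; out_size is a name I am introducing here, not part of the original code):

```python
def out_size(n, kernels=(9, 5, 5)):
    """Spatial size after a stack of unpadded, stride-1 convolutions."""
    for k in kernels:
        n -= k - 1  # each valid convolution trims (k - 1) pixels total
    return n

print(out_size(33))  # 17: matches 33 - 8 - 4 - 4
```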

import torch


class SRCNN(torch.nn.Module):

    def __init__(self):
        super(SRCNN, self).__init__()
        self.layer1 = torch.nn.Conv2d(in_channels=3, out_channels=64,
                                       kernel_size=(9, 9), stride=(1, 1))
        self.layer2 = torch.nn.Conv2d(in_channels=64, out_channels=32,
                                       kernel_size=(5, 5), stride=(1, 1))
        self.layer3 = torch.nn.Conv2d(in_channels=32, out_channels=3,
                                       kernel_size=(5, 5), stride=(1, 1))
        self.relu = torch.nn.ReLU(True)
        self._initialize_weights()

    def forward(self, x):
        f1 = self.relu(self.layer1(x))
        f2 = self.relu(self.layer2(f1))
        f3 = self.layer3(f2)
        return f3

    def _initialize_weights(self):
        for layer in [self.layer1, self.layer2]:
            torch.nn.init.kaiming_normal_(layer.weight.data, nonlinearity='relu')
            torch.nn.init.zeros_(layer.bias.data)
        torch.nn.init.kaiming_normal_(self.layer3.weight.data, nonlinearity='relu')
        torch.nn.init.zeros_(self.layer3.bias.data)

The original paper used Gaussian initialization with mean 0 and standard deviation 0.001. In my v2 training run, I switched to Kaiming/He initialization, which is designed for ReLU networks and provides variance scaling that accounts for the non-linearity. This was one of three training bugs I fixed (more on that below) and contributed to faster, more stable convergence.

The total parameter count is modest:

Layer      Params    Calculation
Layer 1    15,616    3 x 64 x 9 x 9 + 64
Layer 2    51,232    64 x 32 x 5 x 5 + 32
Layer 3     2,403    32 x 3 x 5 x 5 + 3
Total      69,251
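The counts follow directly from the conv shapes: each layer has c_in x c_out x k x k weights plus one bias per output channel. A quick arithmetic check, independent of any framework:

```python
def conv_params(c_in, c_out, k):
    # c_out filters of shape (c_in, k, k), plus one bias per filter
    return c_in * c_out * k * k + c_out

layers = [(3, 64, 9), (64, 32, 5), (32, 3, 5)]
counts = [conv_params(*spec) for spec in layers]
print(counts, sum(counts))  # [15616, 51232, 2403] 69251
```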

Dataset Creation: The Sliding Window Pipeline

This is where the real engineering lives. SRCNN does not train on full images — it trains on small patch pairs. For each full-resolution training image, the pipeline:

  1. Mod-crop: Ensure dimensions are divisible by the scale factor (3x)
  2. Extract 33x33 HR patches using a sliding window with stride 14
  3. Downsample each patch by 3x using LANCZOS, then upsample back to 33x33 using bicubic — this creates the LR input
  4. Crop the HR patch center to 17x17 — this is the target that matches the model’s output size

import os
from PIL import Image


def create_dataset(f_path, output_folder, fsub=33, stride_size=14, scale=3):
    # Output folders for the LR/HR patch pairs
    LR_folder = os.path.join(output_folder, "LR")
    HR_folder = os.path.join(output_folder, "HR")
    os.makedirs(LR_folder, exist_ok=True)
    os.makedirs(HR_folder, exist_ok=True)

    list_of_images = os.listdir(f_path)

    for im_num, file in enumerate(list_of_images):
        full_im_hr = Image.open(os.path.join(f_path, file))
        big_im_hr = mod_crop(full_im_hr, scale)
        im_width, im_height = big_im_hr.size

        if im_width < fsub or im_height < fsub:
            continue

        sub_im_num = 1
        start_height = 0
        for end_height in range(fsub, im_height + 1, stride_size):
            start_width = 0
            for end_width in range(fsub, im_width + 1, stride_size):
                # Extract HR sub-patch
                hr_sub = big_im_hr.crop((start_width, start_height,
                                          end_width, end_height))

                # Downsample 1/3 with LANCZOS, upsample back with bicubic
                sub_w = end_width - start_width
                sub_h = end_height - start_height
                lr_small = hr_sub.resize((sub_w // 3, sub_h // 3), Image.LANCZOS)
                lr_upscaled = lr_small.resize((sub_w, sub_h), Image.BICUBIC)

                # Center-crop HR to 17x17 (model output size)
                center_hr = get_center_hr(hr_sub, 33, 9, 5, 5)

                # Save pair
                lr_upscaled.save(f"{LR_folder}/Y{im_num}_sub{sub_im_num}_LR.bmp")
                center_hr.save(f"{HR_folder}/X{im_num}_sub{sub_im_num}_HR.bmp")

                sub_im_num += 1
                start_width += stride_size
            start_height += stride_size

The mod-crop utility ensures clean division:

def mod_crop(full_image, mod):
    orig_w, orig_h = full_image.size
    new_w = orig_w - (orig_w % mod)
    new_h = orig_h - (orig_h % mod)
    return full_image.crop((0, 0, new_w, new_h))

Using the T91 dataset (91 training images), this produces thousands of patch pairs. With stride 14 and patch size 33, a single 256x256 image yields approximately 256 patches. Across 91 images, that is roughly 20,000-30,000 training pairs, depending on the original image sizes.
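The per-image estimate comes straight from the sliding-window arithmetic. A quick check of the 256x256 figure above:

```python
def patches_per_dim(dim, fsub=33, stride=14):
    # number of valid window positions along one axis
    return (dim - fsub) // stride + 1

n = patches_per_dim(256) ** 2
print(n)  # 256 patches from a single 256x256 image
```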

Why LANCZOS for Downsampling: The Library Disagreement Problem

This was a deliberate choice, and the reasoning behind it ended up being more consequential than the choice itself. The original SRCNN paper used MATLAB’s imresize with bicubic interpolation for downsampling. MATLAB’s implementation applies an anti-aliasing filter by default during downsampling, which prevents aliasing artifacts that would otherwise corrupt the low-resolution images.

PIL’s Image.BICUBIC does not apply anti-aliasing during downsampling. The closest equivalent in PIL is Image.LANCZOS, which applies a Lanczos windowed sinc filter that provides similar anti-aliasing properties. This is noted in the code comments:

"

“Pillow does not allow for both Antialiasing and Bicubic interpolation to be applied at the same time. ANTIALIAS was replaced by LANCZOS. Therefore this implementation uses the LANCZOS Downsampling Filter to replicate the antialiasing properties while still emulating the core style of bicubic interpolation.”

What I did not realize at the time — and what became the subject of an entire separate investigation — is that this choice is one instance of a much larger problem. Every image processing library implements resizing differently: PIL, OpenCV, scikit-image, TensorFlow, PyTorch’s torchvision, and MATLAB all have their own interpolation kernels, coordinate mapping conventions, and boundary handling rules. When two papers both report using “bicubic downsampling,” they may be computing substantially different images depending on which library they used. I documented this in the resizing disagreement study, where I measured PSNR differences of over 21 dB between different libraries performing ostensibly the same resize operation on the same input.

The practical consequence for SRCNN reimplementation is that your training data is a function of your resizing library. Train with PIL’s LANCZOS and evaluate against images downsampled with OpenCV’s INTER_CUBIC, and you are measuring the model’s ability to compensate for interpolation kernel differences as much as you are measuring its super-resolution capability. This is not a hypothetical concern — it is the primary reason my PSNR numbers do not match the paper’s, and it is a problem that affects every super-resolution paper that does not specify its exact preprocessing stack.
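The disagreement is easy to observe directly. A minimal sketch, assuming Pillow and NumPy are available; the synthetic noise image is chosen only because it makes filter differences obvious:

```python
import numpy as np
from PIL import Image

rng = np.random.default_rng(0)
noise = rng.integers(0, 256, size=(99, 99, 3), dtype=np.uint8)
img = Image.fromarray(noise)

# The same nominal operation -- downsample by 3x -- with two different filters
lanczos = np.asarray(img.resize((33, 33), Image.LANCZOS), dtype=np.float64)
bicubic = np.asarray(img.resize((33, 33), Image.BICUBIC), dtype=np.float64)

max_diff = np.abs(lanczos - bicubic).max()
print(max_diff)  # nonzero: the "same" resize produces different pixels
```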

(I also discovered that PIL and OpenCV disagree on the dimension order of images — PIL uses (width, height) for .size while NumPy/OpenCV use (height, width) for .shape — which is the kind of thing that produces silent bugs in resize operations where you accidentally swap width and height. The resize runs without errors, the output has the right total number of pixels, but the aspect ratio is wrong and your PSNR is garbage. I wrote up the parameter swap issue and the dimension order issue separately because each one cost me at least a full day of debugging.)
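The dimension-order mismatch takes only a few lines to demonstrate (a sketch assuming Pillow and NumPy):

```python
import numpy as np
from PIL import Image

img = Image.new("RGB", (640, 480))  # PIL takes and reports (width, height)
arr = np.asarray(img)               # NumPy reports (height, width, channels)

print(img.size)   # (640, 480)
print(arr.shape)  # (480, 640, 3)
```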

Training: Per-Layer Learning Rates

The paper specifies different learning rates for different layers — the reconstruction layer (layer 3) should train with a learning rate 10x smaller than the feature extraction layers:

optimizer = torch.optim.SGD([
    {"params": [model.layer1.weight, model.layer1.bias]},
    {"params": [model.layer2.weight, model.layer2.bias]},
    {"params": model.layer3.parameters(), "lr": 1e-4},
], lr=1e-3, momentum=0.9)

loss_fn = torch.nn.MSELoss()

The rationale: the reconstruction layer operates closest to pixel space, where the gradients are largest and the updates most sensitive to overshooting. A smaller learning rate provides finer control over the final output quality.

Note the learning rates and the absence of weight decay — both of these were bugs I had to fix. More on that in the training bugs section below.

The Normalization Discovery

One of the most instructive debugging sessions during this project: training with raw [0, 255] pixel values caused the loss to explode within the first few batches. The MSE of raw pixel differences is enormous (values in the tens of thousands), and the gradients were too large for the learning rate.

The fix was normalizing to [0, 1]:

def train_one_epoch(model, dataloader, loss_fn, optimizer, device, epoch, total_epochs):
    model.train()
    for batch_idx, (lr_img, hr_img) in enumerate(dataloader):
        lr_img = (lr_img / 255.0).to(device)
        hr_img = (hr_img / 255.0).to(device)

        optimizer.zero_grad()
        output = model(lr_img)
        loss = loss_fn(output, hr_img)
        loss.backward()
        optimizer.step()

This is the same normalization insight I encountered later when building the NumPy neural network for MNIST, where unnormalized pixel values caused NaN within the first few iterations. Numerical stability is never optional.

But the normalization fix created a secondary problem that took me much longer to identify: the paper’s hyperparameters assume [0, 255] pixel ranges. When you normalize to [0, 1], the MSE loss shrinks by a factor of 255^2 = 65,025, and every hyperparameter that was calibrated against [0, 255] gradients is now wrong. The learning rate is too low, the weight initialization is too small, and the weight decay (which I had added as a “reasonable default”) is catastrophically too strong relative to the gradient signal. I did not understand this interdependence until I had spent weeks chasing the symptoms downstream, which is a polite way of saying I wasted a lot of time investigating my evaluation pipeline when the problem was in my training configuration.
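The 255^2 factor is exact: scaling both images by 1/255 scales every squared difference by 1/255^2. A quick numerical check with arbitrary arrays:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.uniform(0, 255, size=10_000)
b = rng.uniform(0, 255, size=10_000)

mse_255 = np.mean((a - b) ** 2)             # loss scale for [0, 255] inputs
mse_01 = np.mean((a / 255 - b / 255) ** 2)  # loss scale for [0, 1] inputs

print(mse_255 / mse_01)  # ~65025 == 255**2
```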

Inference: The Sliding Window Problem

SRCNN is trained on 33x33 patches but needs to work on full images. The inference pipeline uses a sliding window approach:

  1. Pad the input image with reflection padding (border = 8 pixels, the amount each side shrinks)
  2. Slide a 33x33 window across the padded image with stride = 17 (the output patch size)
  3. Run each patch through the model to get a 17x17 output
  4. Stitch the 17x17 output patches back together, averaging overlapping regions

import numpy as np
import torch


def srcnn_inference_full_image(model, lr_upscaled, device, patch_size=33,
                               output_patch=17, stride=17):
    model.eval()
    border = (patch_size - output_patch) // 2  # 8

    h, w, c = lr_upscaled.shape
    padded = np.pad(lr_upscaled,
                    ((border, border), (border, border), (0, 0)),
                    mode='reflect')

    patches, positions = [], []
    for y in range(0, padded.shape[0] - patch_size + 1, stride):
        for x in range(0, padded.shape[1] - patch_size + 1, stride):
            patches.append(padded[y:y+patch_size, x:x+patch_size, :])
            positions.append((y, x))

    # Process in batches for efficiency
    output = np.zeros((h, w, c), dtype=np.float32)
    count = np.zeros((h, w, 1), dtype=np.float32)

    with torch.no_grad():
        for i in range(0, len(patches), 128):
            batch = np.array(patches[i:i+128])
            tensor = torch.from_numpy(batch.transpose(0, 3, 1, 2)).float() / 255.0
            out = (model(tensor.to(device)).cpu().numpy() * 255.0).transpose(0, 2, 3, 1)

            for j, (y, x) in enumerate(positions[i:i+128]):
                output[y:y+output_patch, x:x+output_patch] += out[j]
                count[y:y+output_patch, x:x+output_patch] += 1

    return np.clip(output / np.maximum(count, 1), 0, 255).astype(np.uint8)

The evaluation computes PSNR and SSIM for both the bicubic baseline and the SRCNN output. My first training run (v1) produced results that were worse than bicubic — averaging 20.26 dB vs. bicubic’s 28.61 dB on Set5. The model looked healthy on training patches (30.32 dB) but collapsed on full-image inference. That failure motivated me to investigate the evaluation pipeline itself, leading to the metrics series, but the real problem turned out to be training bugs (see below).

After fixing three training bugs and retraining for 150 epochs (~35 minutes on an RTX 3090 Ti), the v2 model outperforms bicubic on every test image:

Dataset    Metric      Bicubic     SRCNN v2    Delta
Set5       RGB PSNR    28.62 dB    29.55 dB    +0.93 dB
Set5       Y PSNR      30.40 dB    31.45 dB    +1.04 dB
Set5       SSIM        0.8383
Set14      RGB PSNR    25.73 dB    26.40 dB    +0.67 dB
Set14      Y PSNR      27.63 dB    28.39 dB    +0.76 dB
Set14      SSIM        0.7407

The model crossed the bicubic baseline at epoch 21 and continued improving through all 150 epochs. All 19 test images (5 from Set5, 14 from Set14) show SRCNN outperforming bicubic — no regressions.

Training Progression

The v2 training run with corrected hyperparameters converged faster and reached higher quality than v1. The model crossed the bicubic baseline (Set5 RGB PSNR 28.62 dB) at epoch 21, indicating that the architecture is capable of learning the super-resolution mapping quickly once the training configuration is correct.

The rapid improvement in the first 20 epochs followed by gradual convergence is typical for CNN training. The total training time of ~35 minutes on an RTX 3090 Ti for 150 epochs reflects the model’s small parameter count (69,251 parameters).

The Three Training Bugs

My v1 model performed worse than bicubic interpolation. The training patches looked fine (30+ dB PSNR), but full-image inference collapsed to 20.26 dB on Set5. I initially blamed the inference pipeline and spent weeks investigating evaluation methodology — which led to genuinely useful research on resizing disagreement and metric pitfalls — but the root cause turned out to be three training bugs:

Bug 1: Weight Decay That Should Not Have Been There

The original SRCNN paper does not use weight decay. I added weight_decay=1e-4 to the optimizer because it seemed like a reasonable default. It was not. With pixel values normalized to [0, 1], the loss values are tiny (on the order of 0.003), and the weight decay penalty was roughly 65,000x too strong relative to the gradient signal. The regularization was actively preventing the model from learning sharp reconstruction filters.

Bug 2: Learning Rate Too Low by 10x

The paper specifies learning rates for [0, 255] pixel values. When normalizing to [0, 1], the MSE loss shrinks by a factor of 255^2 = 65,025, which means the gradients are proportionally smaller. The learning rate needs to be scaled up to compensate. Increasing from 1e-4 to 1e-3 (with the reconstruction layer at 1e-4 instead of 1e-5) restored the effective learning dynamics.

Bug 3: Gaussian(0, 0.001) Initialization

The paper’s initialization — Gaussian with standard deviation 0.001 — was designed for [0, 255] pixel ranges. With [0, 1] normalization, these weights are far too small to produce meaningful activations in the early training steps. Switching to Kaiming/He initialization (which accounts for the ReLU non-linearity and layer fan-in) provided properly scaled initial weights.

The Lesson

All three bugs stem from the same root cause: the paper’s hyperparameters assume [0, 255] pixel values, but I normalized to [0, 1] for training stability. The normalization was correct, but I failed to adjust the hyperparameters that depend on the data scale. This is a common trap when reimplementing papers — the hyperparameters are not independent of the preprocessing, and papers almost never make this dependency explicit.

Why My Numbers Still Differ From the Paper

Even with the training bugs fixed, my v2 numbers do not match the paper’s exactly. I spent a long time thinking this meant my implementation was wrong. It does not. The remaining differences come from three factors, all of which are properties of the evaluation pipeline rather than the model:

  1. Different resizing libraries: The paper used MATLAB’s imresize for bicubic downsampling. I used PIL’s LANCZOS for downsampling and BICUBIC for upsampling. These produce measurably different images — up to 21 dB different. The training data itself is different because it was generated with different interpolation kernels, which means my model learned a slightly different mapping than the original.

  2. Different data types: My pipeline mixed uint8 and float32 at various stages. The 48 dB shift between uint8 and float32 PSNR was one of the first things I had to debug, and even after fixing the most egregious type mismatches, subtle differences in when the uint8 quantization happens can shift PSNR by tenths of a dB.

  3. Different evaluation protocols: The paper evaluated on the Y channel (luminance) of YCbCr color space. My evaluation includes both RGB and Y-channel metrics for comparison. The Y-channel numbers are closer to the paper’s because the color channels carry less structural information and are therefore less sensitive to interpolation differences.
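The uint8/float32 shift in point 2 has a closed form: computing PSNR with a peak of 255 against data in the [0, 1] range inflates the score by 20·log10(255) ≈ 48.13 dB, independent of the images. A minimal sketch (the psnr helper is a generic textbook definition, not my exact evaluation code):

```python
import numpy as np

def psnr(a, b, data_range):
    # Peak signal-to-noise ratio: 10 * log10(peak^2 / MSE)
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)

rng = np.random.default_rng(0)
clean = rng.uniform(0, 1, size=(64, 64))
noisy = np.clip(clean + rng.normal(0, 0.02, size=(64, 64)), 0, 1)

correct = psnr(clean, noisy, data_range=1.0)     # peak matches the data
mismatched = psnr(clean, noisy, data_range=255)  # peak assumes uint8 range

print(mismatched - correct)  # 20 * log10(255), about 48.13 dB
```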

This experience taught me that “reproducing a paper” and “matching a paper’s numbers” are fundamentally different goals. The architecture and training procedure can be correct even when the numbers differ, because the evaluation pipeline is underspecified in most papers. This realization was the origin of my entire investigation into why PSNR scores are not reproducible.

Reimplementing a Paper Is Harder Than It Sounds

I want to say something about the experience of reimplementation itself, because I think it is undervalued and under-discussed.

When you read a paper, the method section gives you the impression that the authors made a series of clean, deliberate choices: this architecture, this loss function, this learning rate, this dataset. When you try to implement the same paper, you discover that the method section is an iceberg — the visible 20% of the decisions that were made, with the remaining 80% left implicit because the authors considered them obvious, or because the authors used a framework (like MATLAB’s Image Processing Toolbox) that made those decisions for them. What interpolation kernel does “bicubic” mean? What happens at image boundaries? Do you anti-alias during downsampling? What coordinate system do you use? What precision are the intermediate values stored in? The paper does not say, because in MATLAB these are all defaults that the authors may not have been consciously choosing.

PyTorch does not have defaults for most of these. PIL has different defaults. OpenCV has different defaults from PIL. And none of them match MATLAB. So “just implement the paper” becomes “reverse-engineer the implicit decisions that shaped the paper’s results, using a different software stack that makes different implicit decisions.” The architecture is the easy part. The preprocessing pipeline is where the papers live and die.

I do not think this is a criticism of the original SRCNN paper, which was clear and well-written by the standards of its time. It is a structural observation about how papers communicate methods: they specify the novel parts (the architecture, the loss function, the training procedure) and leave the commodity parts (image loading, resizing, color space conversion, data type handling) to the reader’s toolkit. When the reader’s toolkit matches the author’s, this works fine. When it does not, you get months of debugging that feel like failures but are actually the most valuable part of the project, because they force you to understand the system at a depth that no tutorial or lecture will ever reach.

What I Learned

  1. The model is 5% of the work. The three-layer CNN took an hour to implement. The dataset pipeline, evaluation infrastructure, and debugging the preprocessing took months.

  2. Papers underspecify preprocessing. “Bicubic downsampling” is not a unique operation. The library, the filter parameters, the anti-aliasing, and the coordinate mapping all matter, and papers rarely specify them. This is not malice — it is the natural consequence of relying on framework defaults that differ across languages and libraries.

  3. Normalization is not optional, but it changes everything downstream. Every numerical pipeline has implicit assumptions about data ranges. Normalizing to [0, 1] is correct for training stability, but it invalidates hyperparameters that were tuned for [0, 255]. Making the data range explicit and consistent prevents the most common class of training failures, but only if you also re-examine every hyperparameter that was calibrated against the old range.

  4. Per-layer learning rates exist for a reason. The reconstruction layer is more sensitive than the feature extraction layers. This insight generalizes: layers closer to the output benefit from more conservative updates.

  5. The investigation is the project. The SRCNN implementation was supposed to be the starting point for more complex SR models. Instead, the investigation into why my metrics did not match became the most valuable research output — leading to the resizing disagreement study, the clean-fid bug fix, and the SIFT alignment pipeline. I went in to build a model and came out with a research agenda about why image processing libraries disagree and what that means for reproducibility.

  6. MATLAB and Python are not interchangeable, and pretending they are is where half the reproducibility problems in computer vision come from. This is not a provocative claim. It is a measurable fact. The same “bicubic resize” in MATLAB and PIL produces images that differ by enough to shift PSNR by double digits. Any paper that reports PSNR without specifying its resizing stack is reporting a number that is not reproducible outside its original software environment.


This post connects to the Image Quality Metrics series: the metric discrepancies discovered during this SRCNN implementation motivated the entire investigation into PSNR limitations, resizing library disagreement, and the clean-fid bug.