Lecture 2

Image Denoising with SVD

Optimization
Inverse Theory
Python
Torch
SVD
Image denoising and deblurring are important techniques in signal processing and recovery. In this coding exercise, we will explore the application of least squares, SVD, and the pseudoinverse to denoise and deblur images.
Author

Simon Ghyselincks

Published

September 15, 2024

Image Denoising and Deblurring

The motivation for the exercise comes from a real-world problem. When the Hubble Space Telescope was launched, its primary mirror had a defect that caused the images to be blurred. The problem was initially addressed with signal processing techniques that removed the aberrations from the images.

Point Spread Function

For such an image processing problem, we can model the continuous incoming light as striking a 2D mirror that distorts it, followed by a 2D sensor that records it. We suppose that the distortion introduced at the mirror is described by a blur kernel, the point spread function (PSF). Because the PSF acts as a convolution kernel, it behaves as a Green's function for the system in the continuous case:

\[ \vec{b}(x,y) = \int_{\mathcal{X}} \int_{\mathcal{Y}} \vec{G}(x - x', y - y') \vec{u}(x',y') \, dx' dy' \]

where \(\vec{b}(x,y)\) is the blurred image data that is recovered at the sensor, \(\vec{u}(x',y')\) is the true image data, and \(\vec{G}(x,y)\) is the point spread function.

In the special case that the point spread function is \(\delta(x-x',y-y')\), the image is not distorted and the sensor captures the true image data. Our experiment, however, considers cases with potentially severe distortions and examines how they affect our ability to recover the true image data \(\vec{u}(x',y')\) from the sensor data \(\vec{b}(x,y)\).

Discrete PSF

The discrete analog of the continuous PSF is more convenient to work with if we flatten the 2D mesh into a 1D vector, a common operation in signal processing. In the unflattened case we have:

\[ b_{ij} = \sum_{k=1}^{n} \sum_{l=1}^{m} \Delta x \Delta y G(x_i - x_k, y_j - y_l) u_{kl} \]

where \(b\) is the blurred image data at the sensor, \(u\) is the true image data, and \(G\) is the discrete point spread function. In compact notation this is a (2D) convolution: \[ \vec{b} = \vec{G} * \vec{u} \]

Since this is a convolution operation, we can compute it much more quickly by leveraging the convolution theorem: with the fast Fourier transform (FFT), the cost drops from \(O(N^2)\) for the direct sum over \(N\) pixels to \(O(N \log N)\). In the discrete setting, the theorem holds exactly for circular (periodic) convolution.

\[\begin{align} \mathcal{F}(\vec{b}) &= \mathcal{F}(\vec{G} * \vec{u}) \\ \mathcal{F}(\vec{b}) &= \mathcal{F}(\vec{G}) \odot \mathcal{F}(\vec{u}) \\ \vec{b} &= \mathcal{F}^{-1}(\mathcal{F}(\vec{G}) \odot \mathcal{F}(\vec{u})) \end{align} \]

The Hadamard product \(\odot\) is element-wise multiplication, the discrete analog of the pointwise product of two functions.
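To check the convolution theorem numerically, here is a minimal NumPy sketch that evaluates the discrete double sum directly and compares it with the FFT route. It assumes periodic indexing, \(\Delta x \Delta y = 1\), and a kernel stored with its origin at index [0, 0] (the layout that the roll in the gaussianConv code arranges); the helper names conv2_periodic_direct and conv2_periodic_fft are hypothetical.

```python
import numpy as np

def conv2_periodic_direct(G, u):
    """Direct double sum b[i, j] = sum_{k, l} G[(i - k) % n, (j - l) % m] * u[k, l].

    Assumes periodic boundaries, Delta x * Delta y = 1, and G stored with its
    origin at index [0, 0].
    """
    n, m = u.shape
    b = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            for k in range(n):
                for l in range(m):
                    b[i, j] += G[(i - k) % n, (j - l) % m] * u[k, l]
    return b

def conv2_periodic_fft(G, u):
    """Same operation via the convolution theorem: inverse FFT of F(G) * F(u), elementwise."""
    return np.real(np.fft.ifft2(np.fft.fft2(G) * np.fft.fft2(u)))

rng = np.random.default_rng(0)
G = rng.random((8, 8))
u = rng.random((8, 8))
print(np.allclose(conv2_periodic_direct(G, u), conv2_periodic_fft(G, u)))  # True

# A discrete delta kernel (origin at [0, 0]) leaves the image unchanged
delta = np.zeros((8, 8))
delta[0, 0] = 1.0
print(np.allclose(conv2_periodic_fft(delta, u), u))  # True
```

The four nested loops make the \(O(N^2)\) cost of the direct sum visible, which is why it is only practical for tiny images; the FFT version is what the forward model in this lecture uses.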

Matrix Representation of Convolution Operation

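If we flatten the data into a 1D vector, it is possible to construct a matrix operator that performs the convolution. For a 1D signal this is a Toeplitz matrix, a matrix in which each descending diagonal from left to right is constant, so that the rows are sliding windows of the convolution kernel; with periodic boundaries it becomes a circulant matrix, in which each row is the previous row shifted one place to the right. For a flattened 2D image the structure is one level deeper: the matrix is block Toeplitz with Toeplitz blocks (block circulant with circulant blocks for periodic boundaries). In the 1D case we can flatten the PSF, place it in the first row (reversed about its origin, since convolution flips the kernel), and shift it to the right to fill out the remaining rows.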
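A small 1D sketch of the circulant (periodic Toeplitz) construction, assuming a kernel stored with its origin at index 0; circulant_conv_matrix is a hypothetical helper. Building the matrix column by column from shifted copies of the kernel makes the convolution exact, and the checks confirm that the first row is the kernel reversed about its origin and that each row is the previous one shifted right.

```python
import numpy as np

def circulant_conv_matrix(k):
    """Matrix A with A @ u equal to the periodic 1D convolution of k and u.

    Column j is the kernel shifted down by j samples, so A[i, j] = k[(i - j) % n]
    (kernel origin assumed at index 0).
    """
    n = k.shape[0]
    return np.column_stack([np.roll(k, j) for j in range(n)])

k = np.array([0.5, 0.3, 0.0, 0.0, 0.0, 0.2])   # asymmetric kernel, origin at index 0
u = np.arange(6, dtype=float)

A = circulant_conv_matrix(k)
b_fft = np.real(np.fft.ifft(np.fft.fft(k) * np.fft.fft(u)))

print(np.allclose(A @ u, b_fft))               # True: matches the FFT convolution
print(np.allclose(A[0], np.roll(k[::-1], 1)))  # True: first row is k reversed about its origin
print(np.allclose(A[1], np.roll(A[0], 1)))     # True: each row is the previous one shifted right
```

An asymmetric kernel is used on purpose: for a symmetric kernel the reversal is invisible, which can hide a convolution/correlation mix-up.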

Code Implementation

import copy
import math
import os
import time

import matplotlib
#matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.optim import Adam

We start by introducing a point spread function within the PyTorch framework. In this case we work with a parameterized Gaussian kernel.

Gaussian Example

The multivariate extension of the Gaussian function is given by: \[f(x) = \exp\left(-\frac{1}{2} (x-\mu)^T \Sigma^{-1} (x-\mu)\right)\]

where \(\mu\) is the mean vector, \(x\) is a position vector, and \(\Sigma\) is the covariance matrix. The eigenvectors of the covariance matrix give the principal axes of the Gaussian, and the corresponding positive eigenvalues give the variances along those axes. A covariance matrix is always symmetric and positive semidefinite, and it must be positive definite for \(\Sigma^{-1}\) to exist. In the gaussianConv module defined later, \(C\) denotes the inverse of the covariance matrix and \(\mu=0\).
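Since the PyTorch code works with \(C = \Sigma^{-1}\) rather than \(\Sigma\), it can help to build \(C\) directly from the variances and principal axis. A minimal sketch, assuming two dimensions and the same eigen-decomposition \(\Sigma = Q D Q^T\) used in the plotting code; inverse_covariance is a hypothetical helper name. Inverting \(\Sigma\) simply inverts its eigenvalues, \(C = Q D^{-1} Q^T\).

```python
import numpy as np

def inverse_covariance(variances, principal_axis):
    """Build C = Sigma^{-1} = Q diag(1 / lambda) Q^T from variances along two orthogonal axes.

    `variances` are the eigenvalues lambda of Sigma; `principal_axis` is the eigenvector
    paired with variances[0] (normalized here); the second axis is its 90-degree rotation.
    """
    a = np.asarray(principal_axis, dtype=float)
    a = a / np.linalg.norm(a)
    Q = np.column_stack((a, [-a[1], a[0]]))   # orthonormal eigenvectors as columns
    return Q @ np.diag(1.0 / np.asarray(variances, dtype=float)) @ Q.T

variances = [1.0, 2.0]
axis = [1.0, 1.0]
C = inverse_covariance(variances, axis)

# Covariance built the same way as in the plotting code, for comparison
a = np.array(axis) / np.linalg.norm(axis)
Q = np.column_stack((a, [-a[1], a[0]]))
Sigma = Q @ np.diag(variances) @ Q.T

print(np.allclose(C @ Sigma, np.eye(2)))   # True: C is the inverse covariance
print(np.allclose(C, C.T))                 # True: symmetric
print(np.linalg.eigvalsh(C))               # approximately [0.5, 1.0], both positive
```

Converting the result with torch.tensor(C) gives a matrix in the form gaussianConv expects. Note that its exponent is \(-t\,x^T C x\) with no factor of \(1/2\), so \(t\) absorbs that constant.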

from scipy.ndimage import convolve

def multivariate_gaussian(pos, mean, cov):
    """Return the multivariate Gaussian distribution on array pos without using einsum notation."""
    n = mean.shape[0]
    diff = pos - mean
    cov_inv = np.linalg.inv(cov)
    
    # Compute the exponent
    diff_cov_inv = diff @ cov_inv
    exponent = -0.5 * np.sum(diff * diff_cov_inv, axis=-1)
    
    # Compute the normalization factor
    norm_factor = np.sqrt((2 * np.pi) ** n * np.linalg.det(cov))
    
    # Return the Gaussian function
    return np.exp(exponent) / norm_factor

# Define the grid limits and resolution
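# Rows index y and columns index x (image convention), so that Kdx, whose
# entries vary along columns, differentiates in x
Y, X = np.mgrid[-5:5:0.05, -5:5:0.05]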
pos = np.dstack((X, Y))

# Parameters
mean = np.array([0, 0])
eigenvalues = np.array([1, 2])  # Example eigenvalues
principal_axis = np.array([1, 1])  # Example principal axis

# Normalize the principal axis
principal_axis = principal_axis / np.linalg.norm(principal_axis)

# Create the covariance matrix
D = np.diag(eigenvalues)
orthogonal_complement = np.array([-principal_axis[1], principal_axis[0]])
Q = np.column_stack((principal_axis, orthogonal_complement))
cov = Q @ D @ Q.T

# Compute the Gaussian function over the grid
Z = multivariate_gaussian(pos, mean, cov)

# Define the Sobel operators for x and y derivatives
Kdx = np.array([[-1, 0, 1],
                [-2, 0, 2],
                [-1, 0, 1]]) / 4.0

Kdy = np.array([[-1, -2, -1],
                [0,  0,  0],
                [1,  2,  1]]) / 4.0

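# Apply the Sobel filters to compute the derivatives.
# scipy's convolve flips the kernel, so pre-flip it to apply the stencils as written
Zdx = convolve(Z, Kdx[::-1, ::-1], mode='constant', cval=0.0)
Zdy = convolve(Z, Kdy[::-1, ::-1], mode='constant', cval=0.0)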


plt.contourf(X, Y, Z, levels=20, cmap='viridis')
plt.title('Gaussian Distribution')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.axis('equal')
plt.savefig('figure.png', dpi=300, bbox_inches='tight')

# Plot the Gaussian and its derivatives
plt.figure(figsize=(7.5, 2.5))

# Plot the Gaussian
plt.subplot(1, 3, 1)
plt.contourf(X, Y, Z, levels=20, cmap='viridis')
plt.title('Gaussian Distribution')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.axis('equal')

# Plot the derivative in x
plt.subplot(1, 3, 2)
plt.contourf(X, Y, Zdx, levels=20, cmap='RdBu')
plt.title('Derivative in X (Sobel Filter)')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.axis('equal')

# Plot the derivative in y
plt.subplot(1, 3, 3)
plt.contourf(X, Y, Zdy, levels=20, cmap='RdBu')
plt.title('Derivative in Y (Sobel Filter)')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.axis('equal')

plt.tight_layout()
plt.show()

Multivariate Gaussian and its Derivatives

Extending to Combination of Gaussian and Derivative

We can compute the multivariate Gaussian from the inverse covariance matrix \(C\) with mean \(\mu=0\), together with a scaling parameter \(t\) that multiplies the exponent (a larger \(t\) gives a narrower kernel). To form more interesting and varied PSFs, we also take a linear combination of the Gaussian with its Sobel-filtered versions, which approximate the derivatives of the Gaussian.

\[\begin{align} S_x &= \frac{1}{4} \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix} \\ S_y &= \frac{1}{4} \begin{bmatrix} -1 & -2 & -1 \\ 0 & 0 & 0 \\ 1 & 2 & 1 \end{bmatrix} \end{align} \]

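These operators act as discrete derivative (edge-detection) filters. The parameters \(n_0\), \(n_x\), and \(n_y\) in the code weight the normalized Gaussian \(G\) and its two filtered versions, so the combined PSF is \(n_0 G + n_x (S_x \star G) + n_y (S_y \star G)\), where \(\star\) denotes the 2D filtering applied by F.conv2d.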
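To see in what sense the Sobel stencil is a derivative, here is a small NumPy check on a function with a known gradient. It assumes a uniform grid spacing \(h\) (chosen arbitrarily) and applies \(S_x\) as a cross-correlation, which is what F.conv2d computes in the PyTorch code; correlate_valid is a hypothetical helper. With the \(1/4\) normalization the output approximates \(2h\,\partial f/\partial x\), and for a quadratic the central difference is exact.

```python
import numpy as np

# Sobel x-stencil from the text, including the 1/4 factor
S_x = np.array([[-1, 0, 1],
                [-2, 0, 2],
                [-1, 0, 1]]) / 4.0

def correlate_valid(f, k):
    """Apply a 3x3 stencil as a cross-correlation (no kernel flip) at interior points."""
    H, W = f.shape
    out = np.zeros((H - 2, W - 2))
    for di in range(3):
        for dj in range(3):
            out += k[di, dj] * f[di:di + H - 2, dj:dj + W - 2]
    return out

h = 0.1                                 # grid spacing (illustrative choice)
y, x = np.mgrid[0:20, 0:20] * h         # rows index y, columns index x
f = x**2 + 3 * y                        # test function with df/dx = 2x

sobel = correlate_valid(f, S_x)
exact = 2 * x[1:-1, 1:-1]               # df/dx at the interior points
err = np.max(np.abs(sobel / (2 * h) - exact))
print(err)                              # round-off level only
```

The \([1, 2, 1]\) weights across rows smooth in \(y\) while the \([-1, 0, 1]\) pattern differences in \(x\), so the \(3y\) term cancels. Applying the same stencil with a true convolution (for example scipy.ndimage.convolve, which flips the kernel) would reverse the sign of the result.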

class gaussianConv(nn.Module):
    """
    A PyTorch module that applies a Gaussian convolution to an input image using 
    a parameterized Gaussian Point Spread Function (PSF). The PSF is derived 
    from a covariance matrix and the derivatives of the Gaussian are computed 
    for edge detection.

    Args:
        C (torch.Tensor): 2x2 inverse covariance matrix defining the shape of the Gaussian.
        t (float, optional): Multiplier on the exponent, PSF = exp(-t * x^T C x); larger values
            give a narrower Gaussian. Default is np.exp(5).
        n0 (float, optional): Scaling factor for the original PSF, default is 1.
        nx (float, optional): Scaling factor for the derivative along the x-axis, default is 1.
        ny (float, optional): Scaling factor for the derivative along the y-axis, default is 1.
    """
    def __init__(self, C, t=np.exp(5), n0=1, nx=1, ny=1):
        super(gaussianConv, self).__init__()

        self.C = C
        self.t = t
        self.n0 = n0
        self.nx = nx
        self.ny = ny

    def forward(self, image):
        """
        Apply the Gaussian convolution and derivatives to an input image.

        This method performs convolution of the input image with a Gaussian
        Point Spread Function (PSF) that includes the original Gaussian and
        its derivatives along x and y axes. The convolution is performed
        using the Fourier Transform for efficiency.

        Args:
            image (torch.Tensor): Input image tensor of shape (Batch, Channels, Height, Width);
                the PSF is built from the width only, so square images (Height == Width) are assumed.
        
        Returns:
            torch.Tensor: The convolved image of the same shape as the input.
        """

        # Generate the PSF and calculate the center shift required for alignment
        P, center = self.psfGauss(image.shape[-1], image.device)

        # Shift the PSF so that its center aligns with the origin (top-left corner)
        P_shifted = torch.roll(P, shifts=center, dims=[2, 3])

        # Compute the Fourier Transform of the shifted PSF
        S = torch.fft.fft2(P_shifted)

        # Compute the Fourier Transform of the input image
        I_fft = torch.fft.fft2(image)

        # Multiply the Fourier Transforms element-wise (convolution theorem with Hadamard product)
        B_fft = S * I_fft

        # Compute the inverse Fourier Transform to get back to the spatial domain
        B = torch.real(torch.fft.ifft2(B_fft))

        # Return the convolved image
        return B

    def psfGauss(self, dim, device='cpu'):
        """
        Generate the Gaussian PSF and its derivatives.

        Args:
            dim (int): Dimension size (assumes square dimensions).
            device (str, optional): Device to create tensors on, default is 'cpu'.

        Returns:
            tuple:
                - PSF (torch.Tensor): The combined PSF including derivatives.
                - center (list): Shifts required to align the PSF with the origin.
        """
        # Define the size of the PSF kernel (assumed to be square)
        m = dim
        n = dim

        # Create a meshgrid of (X, Y) coordinates
        x = torch.arange(-m // 2 + 1, m // 2 + 1, device=device)
        y = torch.arange(-n // 2 + 1, n // 2 + 1, device=device)
        X, Y = torch.meshgrid(x, y, indexing='ij')
        X = X.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, m, n)
        Y = Y.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, m, n)

        # Extract elements from the covariance matrix
        # Assuming self.C is a 2x2 tensor
        cx, cy, cxy = self.C[0, 0], self.C[1, 1], self.C[0, 1]

        # Compute the Gaussian PSF using the meshgrid and covariance elements
        PSF = torch.exp(-self.t * (cx * X ** 2 + cy * Y ** 2 + 2 * cxy * X * Y))

        # Normalize the PSF so that its absolute sum is 1
        PSF0 = PSF / torch.sum(PSF.abs())

        # Define derivative kernels (Sobel operators) for edge detection
        Kdx = torch.tensor([[-1, 0, 1],
                            [-2, 0, 2],
                            [-1, 0, 1]], dtype=PSF0.dtype, device=device) / 4
        Kdy = torch.tensor([[-1, -2, -1],
                            [0, 0, 0],
                            [1, 2, 1]], dtype=PSF0.dtype, device=device) / 4

        # Reshape kernels to match convolution requirements
        Kdx = Kdx.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, 3, 3)
        Kdy = Kdy.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, 3, 3)

        # Convolve the PSF with the derivative kernels to obtain derivatives
        # Padding ensures the output size matches the input size
        PSFdx = F.conv2d(PSF0, Kdx, padding=1)
        PSFdy = F.conv2d(PSF0, Kdy, padding=1)

        # Combine the original PSF and its derivatives using the scaling factors
        PSF_combined = self.n0 * PSF0 + self.nx * PSFdx + self.ny * PSFdy

        # Calculate the center shift required to align the PSF with the origin
        center = [1 - m // 2, 1 - n // 2]

        # Return the combined PSF and center shift
        return PSF_combined, center

Creating a Toy Dataset

Often in computational science we test our strategies on toy datasets: simplified data that allow for easier debugging and understanding of the problem at hand. Here, rather than use a real image, we construct a geometric image whose denoising and deblurring results are easy to check visually. The dataset also has batch and color-channel dimensions, following the \(B \times C \times H \times W\) convention for PyTorch tensors (and later for some machine learning frameworks). A single-sample, single-channel \(256 \times 256\) image therefore has dimensions \(1 \times 1 \times 256 \times 256\).

x = torch.zeros(1, 1, 256, 256)
x[:,:, 120:140, 120:140] = 1.0
x[:,:, 100:120, 100:120] = -1.0

plt.figure(figsize=(3,3))
plt.imshow(x[0,0,:,:])

A sample toy dataset for image denoising and deblurring.

This simple image consists of two square regions, one high (+1) and one low (-1), which we will try to recover after applying a point spread function to the image (the forward model). The forward model is the convolution of the image with the PSF.

C = torch.tensor([[1, 0],[0, 1]])
Amv = gaussianConv(C, t=0.001,n0=0, nx=1,  ny=-1)

y = Amv(x)
plt.subplot(1,2,1)
plt.imshow(x[0,0,:,:])
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(y[0,0,:,:])
plt.colorbar()
print()

Forward model for image denoising and deblurring.

Forming a Convolution Matrix

Returning to the idea of a Toeplitz-style matrix, we first flatten the data to 1D and then recover the matrix in one of two ways: directly in the spatial domain, or by exploiting the linearity of the operator. In the spatial approach, the first row of the matrix is determined by the convolution for the first output element, and each subsequent row is that row shifted by one. The matrix can be very large: an \(n\times m\) image has \(nm\) elements once flattened, requiring an \(nm \times nm\) matrix. For the \(256 \times 256\) image that would be a \(65536 \times 65536\) matrix (about 4.3 billion entries), so we reduce the image to \(32 \times 32\), giving a \(1024 \times 1024\) matrix.

Note that the PSF is applied through the FFT (after being centered with torch.roll), which implicitly assumes periodic boundary conditions in both \(x\) and \(y\): content leaving one edge of the image re-enters at the opposite edge. There are many other ways to treat the boundary in a convolution, such as zero padding or mirroring the boundary. Coding these by hand is a good exercise for understanding the convolution operation, but it is not the purpose of this exercise.
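A 1D sketch of the periodic boundary effect, assuming a three-tap blur \([0.25, 0.5, 0.25]\) (an arbitrary choice for illustration): a spike at the right edge is blurred once with the FFT and once with zero-padded convolution via np.convolve.

```python
import numpy as np

n = 8
u = np.zeros(n)
u[-1] = 1.0                                    # a spike at the right-hand edge

# Blur kernel [0.25, 0.5, 0.25] stored with its centre at index 0 (periodic layout)
k_periodic = np.zeros(n)
k_periodic[[0, 1, -1]] = [0.5, 0.25, 0.25]

# FFT convolution: implicitly periodic, so the right edge wraps onto the left edge
circular = np.real(np.fft.ifft(np.fft.fft(k_periodic) * np.fft.fft(u)))

# Zero-padded ("linear") convolution of the same kernel, cropped to the input length
linear = np.convolve(u, [0.25, 0.5, 0.25], mode="same")

print(circular[0], linear[0])       # about 0.25 (wrapped around) versus 0
print(circular.sum(), linear.sum()) # about 1.0 (intensity kept) versus 0.75 (intensity lost)
```

With the FFT, part of the spike re-enters at index 0 and the total intensity is preserved; with zero padding, the quarter that falls off the edge is simply lost.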

Direct Recovery of Convolution Matrix

dim = 32
x = torch.zeros(1, 1, dim, dim)
x[:,:, 12:14, 12:14] = 1.0
x[:,:, 10:12, 10:12] = -1.0

Amv = gaussianConv(C, t=0.1,n0=1, nx=0.1,  ny=0.1)

# Flatten the image and the PSF
x_flat = x.flatten()

kernel, center = Amv.psfGauss(x.shape[-1]) # Get a square conv kernel 

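# Transpose the kernel (swap its x and y axes) before filling the convolution matrix.
# Note: a transpose is not the reversal about the origin that separates convolution
# from correlation, so this matrix may not match the FFT-based forward model exactly.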

kernel = kernel.transpose(2,3) 
# Roll shifts the kernel from the center of the box to the top left corner
kernel_shifted = torch.roll(kernel, shifts=center, dims=[2, 3])

plt.subplot(1,3,1)
plt.imshow(kernel[0,0,:,:])
plt.title('PSF Centered')
plt.subplot(1,3,2)
plt.title('PSF Shifted with Roll')
plt.imshow(kernel_shifted[0,0,:,:])

# Flatten the kernel
kernel_flat = kernel_shifted.flatten()

# Form the convolution matrix
n = x_flat.shape[0]
m = kernel_flat.shape[0]
A_conv = torch.zeros(n, n)

for i in range(n):
    A_conv[i, :] = torch.roll(kernel_flat, shifts=i, dims=[0])

plt.subplot(1,3,3)
plt.imshow(A_conv)
plt.title('Convolution Matrix');

Forming a convolution matrix for the forward model.

Recovery Using Linearity of Operator

Since the convolution operation being performed is linear, another way to recover the matrix operator is to apply it to each standard basis vector and read off the resulting column:

\[\begin{bmatrix} \mathbf{a}_1 \mid \mathbf{a}_2 \mid \ldots \mid \mathbf{a}_n \end{bmatrix} \mathbf{e}_i = \mathbf{A} \mathbf{e}_i = \mathbf{a}_i\]

where \(\mathbf{e}_i\) is the \(i\)th standard basis vector (a 1 in position \(i\) and zeros elsewhere). For an image, each basis vector is a single bright pixel, so each column of \(\mathbf{A}\) is the blurred image of one pixel.
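As a side note, for a linear operator the Jacobian is the matrix itself, so automatic differentiation can recover the same columns. A minimal sketch, assuming a small hypothetical blur function (periodic convolution via the FFT with a random kernel) in place of gaussianConv; torch.autograd.functional.jacobian computes all the derivatives, and the loop is the basis-vector method described above.

```python
import torch

torch.manual_seed(0)
n = 6
kernel = torch.rand(n, n, dtype=torch.float64)   # hypothetical kernel, origin at [0, 0]

def blur(u_flat):
    """Hypothetical linear operator: 2D periodic convolution via the FFT."""
    u = u_flat.reshape(n, n)
    return torch.real(torch.fft.ifft2(torch.fft.fft2(kernel) * torch.fft.fft2(u))).flatten()

# Column-by-column recovery: A e_k is the k-th column of A
N = n * n
A_cols = torch.zeros(N, N, dtype=torch.float64)
for k in range(N):
    e_k = torch.zeros(N, dtype=torch.float64)
    e_k[k] = 1.0
    A_cols[:, k] = blur(e_k)

# For a linear map the Jacobian is the matrix itself, so autograd gives the same answer
A_jac = torch.autograd.functional.jacobian(blur, torch.zeros(N, dtype=torch.float64))

print(torch.allclose(A_cols, A_jac))   # True
```

Both approaches cost one operator application (or one backward pass) per pixel, so they scale the same way; the basis-vector loop has the advantage of working with any black-box linear operator, even one that PyTorch cannot differentiate.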

A_conv_lin = torch.zeros(n, n)

k=0
for i in range(x.shape[-2]):
  for j in range(x.shape[-1]):
    e_ij = torch.zeros_like(x)
    e_ij[:,:, i, j] = 1.0
    y = Amv(e_ij)
    A_conv_lin[:, k] = y.flatten()
    k = k+1

plt.subplot(1,2,1)
plt.imshow(A_conv_lin)
plt.title('Convolution Matrix (Linear)')
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(A_conv-A_conv_lin)
plt.title('Difference from Direct')
plt.colorbar()

Forming a convolution matrix for the forward model using linearity.

Next we compare both matrices against the forward model computed with the gaussianConv class defined earlier:

b_forward = Amv(x)

b_mat_toeplitz = A_conv @ x_flat
b_mat_linear = A_conv_lin @ x_flat

plt.subplot(1,3,1)
plt.imshow(b_forward[0,0,:,:])
plt.subplot(1,3,2)
plt.imshow(b_mat_toeplitz.reshape(x.shape[-2:]))
plt.subplot(1,3,3) 
plt.imshow(b_mat_linear.reshape(x.shape[-2:]))

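The two matrix-based results differ, even though in principle they should agree. The discrepancy comes from the direct construction. Rolling the flattened kernel by one position per row wraps the last pixel of each image row into the first pixel of the next row, whereas a 2D periodic convolution wraps each row back onto itself. The roll also places the kernel unreversed in each row, which computes a correlation rather than a convolution; this matters for the antisymmetric derivative terms of the PSF, and transposing the kernel does not compensate for it. The column-extraction method avoids both issues because it relies only on the operator being linear, and it generalizes to any linear forward model, so we will continue with it.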
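To make the structure explicit, here is a NumPy sketch that builds the periodic 2D convolution matrix row by row using 2D rolls, which yields a block circulant matrix with circulant blocks, and compares it with the FFT result and with the 1D roll of the flattened kernel used in the direct construction. It assumes a square random (non-symmetric) test kernel whose origin has already been moved to index [0, 0], analogous to kernel_shifted[0, 0] but without the transpose; bccb_matrix is a hypothetical helper name.

```python
import numpy as np

def bccb_matrix(psf0):
    """Matrix of the 2D periodic convolution b = psf0 * u acting on u.ravel().

    `psf0` is the PSF with its centre already at index [0, 0]. Row p = (i, j)
    satisfies A[p, (k, l)] = psf0[(i - k) % n1, (j - l) % n2].
    """
    n1, n2 = psf0.shape
    # Reverse the kernel about its origin: flipped[a, b] = psf0[-a % n1, -b % n2]
    flipped = np.roll(np.flip(psf0, axis=(0, 1)), shift=(1, 1), axis=(0, 1))
    A = np.empty((n1 * n2, n1 * n2))
    for i in range(n1):
        for j in range(n2):
            # A 2D shift wraps each image row onto itself rather than into the next row
            A[i * n2 + j] = np.roll(flipped, shift=(i, j), axis=(0, 1)).ravel()
    return A

rng = np.random.default_rng(1)
n = 6
psf0 = rng.random((n, n))        # deliberately non-symmetric test kernel
u = rng.random((n, n))

b_fft = np.real(np.fft.ifft2(np.fft.fft2(psf0) * np.fft.fft2(u))).ravel()
A = bccb_matrix(psf0)

# The 1D-roll construction: row p is the flattened kernel shifted by p
A_roll = np.stack([np.roll(psf0.ravel(), p) for p in range(n * n)])

print(np.allclose(A @ u.ravel(), b_fft))       # True
print(np.allclose(A_roll @ u.ravel(), b_fft))  # expected False for a generic kernel
```

The 2D roll keeps each row's wrap-around within that row, which is the defining property of the block circulant structure.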

Amat = A_conv_lin

Least Squares Recovery with SVD and Pseudoinverse

Now that we have recovered a matrix operator, we can formulate the forward problem as \(A\mathbf{x} = \mathbf{b}\) with known \(A\) and \(\mathbf{b}\), and we want to recover \(\mathbf{x}\). To do this, we use the singular value decomposition (SVD) to construct the pseudoinverse. We can also filter out the singular values that are very small, for example with a cutoff value, to improve the conditioning of the problem.
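A minimal sketch of a truncated-SVD pseudoinverse, assuming an absolute cutoff as in the threshold experiments of this lecture; truncated_pinv is a hypothetical helper, and the test matrix is synthetic with prescribed singular values. It uses torch.linalg.svd, which returns \(V^T\) (named Vh) rather than the \(V\) returned by torch.svd, and cross-checks against torch.linalg.pinv, whose atol argument zeroes singular values below the given tolerance.

```python
import torch

def truncated_pinv(A, threshold):
    """Pseudoinverse keeping only singular values >= threshold (absolute cutoff)."""
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    S_inv = torch.where(S >= threshold, 1.0 / S, torch.zeros_like(S))
    return Vh.T @ torch.diag(S_inv) @ U.T

torch.manual_seed(0)
# Synthetic test matrix with singular values spanning many orders of magnitude
Q1, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
Q2, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
sigma = torch.tensor([10.0, 1.0, 1e-3, 1e-9], dtype=torch.float64)
A = Q1 @ torch.diag(sigma) @ Q2.T

A_tsvd = truncated_pinv(A, 1e-6)
# torch.linalg.pinv zeroes singular values below max(atol, rtol * sigma_max)
A_ref = torch.linalg.pinv(A, atol=1e-6)

print(torch.allclose(A_tsvd, A_ref))             # True
print(torch.linalg.matrix_norm(A_tsvd, ord=2))   # about 1e3 = 1 / 1e-3, not 1e9
```

Dropping the \(10^{-9}\) singular value caps the norm of the pseudoinverse at \(1/10^{-3}\), which is exactly the conditioning improvement the cutoff is meant to buy.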

SVD Decomposition

U, S, V = torch.svd(Amat.to(torch.float64))
b = Amv(x)

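Now we make a log plot of the singular values to see how they decay, noting that we lose numerical precision around the \(10^{-6}\) mark (likely because the matrix entries were computed in single precision before the conversion to float64). We can also assess the error in the decomposition by comparing the original matrix with its reconstruction \(U\,\mathrm{diag}(S)\,V^T\); the code computes the mean squared error between them, which is the squared Frobenius norm of the difference divided by the number of entries.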

plt.semilogy(S)
plt.xlabel('Singular Value Index')
plt.ylabel('Singular Value')

loss = F.mse_loss(Amat, U @ torch.diag(S) @ V.T)
print(f"The loss is {loss}")
The loss is 1.812403923995022e-34

SVD Decomposition of the Convolution Matrix.

The loss is quite small, which is a good sign that the decomposition is accurate to within the numerical precision of the machine.
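Because a periodic convolution matrix is block circulant with circulant blocks, its singular values can also be read off the Fourier transform of the PSF: they are the magnitudes \(|\mathcal{F}(\vec{G})|\). A sketch under the assumptions of periodic boundaries and a random \(8 \times 8\) kernel stored with its origin at [0, 0] (periodic_conv is a hypothetical helper) that checks this against torch.linalg.svdvals.

```python
import torch

torch.manual_seed(0)
n = 8
psf = torch.rand(n, n, dtype=torch.float64)   # any kernel, origin at [0, 0]

def periodic_conv(u_flat):
    """Hypothetical forward model: 2D periodic convolution with psf via the FFT."""
    u = u_flat.reshape(n, n)
    return torch.real(torch.fft.ifft2(torch.fft.fft2(psf) * torch.fft.fft2(u))).flatten()

# Recover the matrix column by column, as in the linearity approach
A = torch.stack([periodic_conv(e) for e in torch.eye(n * n, dtype=torch.float64)], dim=1)

sv_matrix = torch.linalg.svdvals(A)   # sorted in descending order
sv_fourier = torch.fft.fft2(psf).abs().flatten().sort(descending=True).values

print(torch.allclose(sv_matrix, sv_fourier))   # True
```

This is why the singular value plot mirrors the frequency response of the PSF: frequencies that the blur suppresses become tiny singular values.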

Initial Attempt at Pseudoinverse

To recover the original image data we first naively try to invert the matrix to see what happens.

xhat = torch.linalg.solve(Amat,b.reshape(dim**2))
plt.subplot(1,2,1)
plt.imshow(xhat.reshape(x.shape[-2:]))
plt.title('Naive Inverse')
plt.subplot(1,2,2)
plt.imshow(x.reshape(x.shape[-2:]))
plt.title('Original Image');

Naive Pseudoinverse Recovery of the Original Image.

Wow, not even close! The matrix is so ill-conditioned that it is numerically low rank: dividing by its tiny singular values amplifies rounding errors in \(\mathbf{b}\) until they swamp the solution. We can improve the situation by filtering out the singular values that are very small.

Pseudoinverse with Filtering

We can filter out the poorly conditioned singular values and exclude them from the inversion. To see what the singular values are doing, we can plot the singular vectors corresponding to the first few of them. In the SVD, the most important information about how the matrix acts is captured in the left-most columns of \(U\), which pair with the largest singular values.

n= 5
for i in range(n):
  plt.subplot(1,n,i+1)
  plt.imshow(U[:,i].reshape(x.shape[-2:]))
  plt.title(f'Mode {i}')

For the inverse problem, the most important directions are, conversely, the left-most columns of \(V\): they are the image patterns paired with the largest singular values, so they can be recovered reliably. We can also check what the right-most columns of \(V\) are doing. Their coefficients are divided by the smallest singular values during inversion, so any error in those components is greatly amplified. They are high-frequency modes of the image, and they cause the reconstruction issues when they are contaminated by numerical round-off error.
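A small experiment makes the amplification concrete. This sketch assumes a synthetic matrix \(A = U\,\mathrm{diag}(\sigma)\,V^T\) with random orthogonal \(U, V\) and singular values from 1 down to \(10^{-12}\), rather than the convolution matrix itself. A data perturbation of size \(10^{-8}\) along the last left singular vector \(\mathbf{u}_n\) should move the solution by \(10^{-8}/\sigma_n = 10^{4}\) along \(\mathbf{v}_n\).

```python
import torch

torch.manual_seed(0)
n = 50
Q1, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
Q2, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
sigma = torch.logspace(0, -12, n, dtype=torch.float64)   # 1 down to 1e-12
A = Q1 @ torch.diag(sigma) @ Q2.T                        # U = Q1, V = Q2

# Perturb the data by 1e-8 along the last left singular vector u_n ...
delta_b = 1e-8 * Q1[:, -1]
delta_x = torch.linalg.solve(A, delta_b)

# ... and the solution moves along v_n with size 1e-8 / 1e-12 = 1e4
norm_dx = torch.linalg.vector_norm(delta_x)
along_vn = torch.abs(delta_x @ Q2[:, -1])
print(norm_dx, along_vn)   # both about 1e4
```

A perturbation of the same size along \(\mathbf{u}_1\) would instead move the solution by only \(10^{-8}\), which is why the trailing modes dominate the error.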

n= 5
for i in range(n):
  plt.subplot(1,n,i+1)
  plt.imshow(V[:,i].reshape(x.shape[-2:]))
  plt.title(f'Mode {i}')
plt.show()

for i in range(n):
  plt.subplot(1,n,i+1)
  plt.imshow(V[:,-(i+1)].reshape(x.shape[-2:]))
  plt.title(f'Mode {V.shape[1]-1-i}')
plt.show()

The leading modes are the most important ones, as they contain the big-picture detail without the high-frequency noise. We can now filter out the singular values that are very small and invert the matrix to recover the original image.

b_flat = b.flatten().to(torch.float64)
x_flat = x.flatten().to(torch.float64)
thresholds = [1e-1, 1e-3, 1e-6, 1e-7, 1e-8, 1e-10]

plt.figure(figsize=(7,5))  # Adjust the figure size as needed

for idx, threshold in enumerate(thresholds):
    # Filter the singular values
    S_filtered = S.clone()
    S_filtered[S_filtered < threshold] = 0

    # Compute the reciprocal of the filtered singular values
    S_inv = torch.zeros_like(S_filtered)
    non_zero_mask = S_filtered > 0
    S_inv[non_zero_mask] = 1 / S_filtered[non_zero_mask]

    # Construct the pseudoinverse of Amat
    A_pinv = V @ torch.diag(S_inv) @ U.T

    # Reconstruct the original image
    xhat = A_pinv @ b_flat

    # Compute the reconstruction error
    error = torch.norm(xhat - x_flat, p='fro').item()

    # Plot the reconstructed image in the appropriate subplot
    plt.subplot(2, 3, idx + 1)  # idx + 1 because subplot indices start at 1
    plt.imshow(xhat.reshape(x.shape[-2:]))
    plt.title(f'Threshold {threshold}\nError: {error:.4f}')
    plt.colorbar()
    plt.axis('off')  # Optionally turn off axis ticks and labels

plt.tight_layout()
plt.show()

Pseudoinverse Recovery of the Original Image with Filtering.

Looking at the results, recovery is best around the \(10^{-7}\) threshold, as measured by the Euclidean norm of the reconstruction error. But what happens when we add noise to the data signal?

Adding Noise to the Signal

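Now we add Gaussian noise \(\boldsymbol{\eta}\) with standard deviation \(\alpha = 0.01\) to the signal and solve for the image again, this time through the regularized normal equations \((A^T A + \alpha^2 I)\,\mathbf{x} = A^T(\mathbf{b} + \boldsymbol{\eta})\). These minimize \(\|A\mathbf{x} - (\mathbf{b} + \boldsymbol{\eta})\|^2 + \alpha^2 \|\mathbf{x}\|^2\), a least squares problem with a small Tikhonov penalty that stabilizes the direct solution.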

b_flat = b.flatten().to(torch.float64)
x_flat = x.flatten().to(torch.float64)
Amat = Amat.to(torch.float64)

alpha = .01
noise = torch.randn_like(b_flat) * alpha

H = Amat.T @ Amat + alpha**2 * torch.eye(Amat.shape[0])
xhat = torch.linalg.solve(H, Amat.T @ (b_flat + noise))

plt.subplot(1,2,1)
plt.imshow(x[0,0])
plt.title('Original Image')
plt.subplot(1,2,2)
plt.imshow(xhat.reshape(x.shape[-2:]))
plt.title('Reconstructed Image');

Regularized Least-Squares Recovery of the Original Image with Noise.

The reconstruction is not very good: the noise has been amplified all over the image. We can try the pseudoinverse method again with the noise added to the signal.

# The noise is in the data (b_flat + noise), not in the operator, so we reuse
# the SVD of the noise-free matrix Amat rather than perturbing Amat itself
Un, Sn, Vn = torch.svd(Amat)

thresholds = [.5, .1, .05, .03, .005, .001]

plt.figure(figsize=(7,5))  # Adjust the figure size as needed

for idx, threshold in enumerate(thresholds):
    # Filter the singular values
    S_filtered = Sn.clone()
    S_filtered[S_filtered < threshold] = 0

    # Compute the reciprocal of the filtered singular values
    S_inv = torch.zeros_like(S_filtered)
    non_zero_mask = S_filtered > 0
    S_inv[non_zero_mask] = 1 / S_filtered[non_zero_mask]

    # Construct the pseudoinverse of Amat
    A_pinv = Vn @ torch.diag(S_inv) @ Un.T

    # Reconstruct the original image
    xhat = A_pinv @ (b_flat + noise)

    # Compute the reconstruction error
    error = torch.norm(xhat - x_flat, p='fro').item()

    # Plot the reconstructed image in the appropriate subplot
    plt.subplot(2, 3, idx + 1)  # idx + 1 because subplot indices start at 1
    plt.imshow(xhat.reshape(x.shape[-2:]))
    plt.title(f'Threshold {threshold}\nError: {error:.4f}')
    plt.colorbar()
    plt.axis('off')  # Optionally turn off axis ticks and labels

plt.tight_layout()
plt.show()

Pseudoinverse Recovery of the Original Image with Noise.

Even this small amount of noise has a large effect on the threshold needed for a good reconstruction: a higher cutoff on the singular values becomes important when the signal is noisy. Previously numerical precision was the main limitation, but now the measurement noise is.
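The regularized normal equations used for the noisy example are closely related to the truncated SVD. With \(A = U \Sigma V^T\),
\[(A^T A + \alpha^2 I)^{-1} A^T \mathbf{b} = \sum_i \frac{\sigma_i}{\sigma_i^2 + \alpha^2} (\mathbf{u}_i^T \mathbf{b})\, \mathbf{v}_i,\]
so instead of a hard cutoff, each \(1/\sigma_i\) is replaced by a smooth filter that behaves like \(1/\sigma_i\) when \(\sigma_i \gg \alpha\) and like \(\sigma_i/\alpha^2\) when \(\sigma_i \ll \alpha\). A sketch that checks the identity on a random ill-conditioned test matrix (an assumption for illustration; tikhonov_via_svd is a hypothetical helper):

```python
import torch

def tikhonov_via_svd(A, b, alpha):
    """x = sum_i sigma_i / (sigma_i^2 + alpha^2) * (u_i^T b) v_i."""
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    filt = S / (S**2 + alpha**2)               # smooth replacement for 1 / sigma_i
    return Vh.T @ (filt * (U.T @ b))

torch.manual_seed(0)
n = 40
A = torch.randn(n, n, dtype=torch.float64) @ torch.diag(torch.logspace(0, -6, n, dtype=torch.float64))
b = torch.randn(n, dtype=torch.float64)
alpha = 1e-2

# Normal-equations form used in the noisy example: (A^T A + alpha^2 I) x = A^T b
H = A.T @ A + alpha**2 * torch.eye(n, dtype=torch.float64)
x_normal = torch.linalg.solve(H, A.T @ b)
x_svd = tikhonov_via_svd(A, b, alpha)

print(torch.allclose(x_normal, x_svd, rtol=1e-6, atol=1e-6))   # True
```

The SVD form also makes it cheap to try several values of \(\alpha\), since the decomposition is computed once and only the filter changes, much like sweeping thresholds in the truncated SVD experiments.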