# This file is part of xlens.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Image utilities for working with LSST exposures and PSF models.
This module collects helper routines that are repeatedly used across
``xlens`` when generating or post-processing simulated images. The
implementations originate from the LSST Science Pipelines, and the
docstrings have been expanded here to clarify how they interact with the
rest of ``xlens``.
"""
from typing import Any, List, Sequence
import anacal
import astropy
import lsst.geom as lsst_geom
import numpy as np
from numpy.lib import recfunctions as rfn
from numpy.typing import NDArray
[docs]
# Mask planes treated as bad by default. This list is the default value of
# ``prepare_data``'s ``badMaskPlanes`` argument; when ``prepare_data`` builds
# its pixel mask, any pixel carrying one of these bits is flagged.
badMaskDefault = (
    "BAD SAT CR NO_DATA UNMASKEDNAN CROSSTALK INTRP STREAK VIGNETTED CLIPPED"
).split()
[docs]
def subpixel_shift(image: NDArray, shift_x: float, shift_y: float) -> NDArray:
    """Translate an image by a (possibly fractional) pixel offset.

    The shift is applied as a linear phase ramp in Fourier space. The 2-D
    FFT of ``image`` is multiplied by
    ``exp(-2j * pi * (shift_x * kx + shift_y * ky))``, where ``kx`` and
    ``ky`` are ``numpy.fft.fftfreq`` frequencies, and the inverse FFT is
    taken. The discrete transform is periodic, so content pushed past one
    edge reappears on the opposite edge.

    Parameters
    ----------
    image
        Two-dimensional array to shift. Axis 0 is y and axis 1 is x.
    shift_x
        Offset along x (axis 1) in pixels. Any real value is allowed;
        positive values move features towards larger x.
    shift_y
        Offset along y (axis 0) in pixels. Positive values move features
        towards larger y.

    Returns
    -------
    numpy.ndarray
        Real part of the inverse transform, with the same shape as
        ``image``.
    """
    n_rows, n_cols = image.shape
    # Row / column vectors of frequencies; broadcasting builds the 2-D ramp.
    freq_x = np.fft.fftfreq(n_cols)[np.newaxis, :]
    freq_y = np.fft.fftfreq(n_rows)[:, np.newaxis]
    ramp = np.exp(-2j * np.pi * (shift_x * freq_x + shift_y * freq_y))
    return np.fft.ifft2(np.fft.fft2(image) * ramp).real
[docs]
def resize_array(
    array: NDArray[Any],
    target_shape: tuple[int, int] = (64, 64),
):
    """Crop and/or zero-pad a 2-D array to ``target_shape``.

    Each axis is handled independently, and cropping happens before
    padding:

    * an axis longer than requested is cropped to a window of the target
      length starting at ``(input_length - target_length) // 2``;
    * an axis shorter than requested is zero-padded. When the padding
      amount is odd, the extra pixel goes *before* the data (lower index).

    If no axis needs changing, the input object itself is returned.

    Parameters
    ----------
    array
        Two-dimensional input array.
    target_shape
        ``(height, width)`` of the output; it need not be square.

    Returns
    -------
    numpy.ndarray
        Array of shape ``target_shape`` with the same dtype as ``array``.
    """
    want_ny, want_nx = target_shape
    have_ny, have_nx = array.shape

    # Cropping: slice a window of the requested length from each long axis.
    if have_ny > want_ny:
        y0 = (have_ny - want_ny) // 2
        array = array[y0 : y0 + want_ny, :]
    if have_nx > want_nx:
        x0 = (have_nx - want_nx) // 2
        array = array[:, x0 : x0 + want_nx]

    # Padding: ceil(deficit / 2) zeros before the data, floor(deficit / 2)
    # after it.
    extra_y = max(want_ny - have_ny, 0)
    extra_x = max(want_nx - have_nx, 0)
    if extra_y or extra_x:
        array = np.pad(
            array,
            (
                (extra_y - extra_y // 2, extra_y // 2),
                (extra_x - extra_x // 2, extra_x // 2),
            ),
            mode="constant",
            constant_values=0.0,
        )
    return array
[docs]
class LsstPsf(anacal.psf.BasePsf):
    """Adapter that exposes an LSST PSF model with an ``anacal`` interface.

    Parameters
    ----------
    psf
        LSST PSF model; it must provide ``computeImage(Point2D)``.
    npix
        Side length of the square stamps returned by :meth:`draw`.
    lsst_bbox
        Optional bounding box. When given, its minimum corner is added to the
        positions passed to :meth:`draw`, so callers can use coordinates
        relative to the box origin. When ``None``, positions are used as-is.
    """

    def __init__(self, psf, npix, lsst_bbox=None):
        super().__init__()
        # Store the wrapped model: ``draw`` evaluates ``self.psf``, and
        # without this assignment every ``draw`` call raised AttributeError.
        self.psf = psf
        self.shape = (npix, npix)
        if lsst_bbox is None:
            self.x_min = 0.0
            self.y_min = 0.0
        else:
            min_corner = lsst_bbox.getMin()
            # Offset that converts box-relative positions to the
            # coordinates expected by ``computeImage``.
            self.x_min = min_corner.getX()
            self.y_min = min_corner.getY()

    def draw(self, x, y):
        """Return the PSF stamp at box-relative position ``(x, y)``.

        The model is evaluated at ``(x + x_min, y + y_min)``. The resulting
        image is cropped/zero-padded to ``self.shape`` with
        :func:`resize_array`.
        """
        this_psf = self.psf.computeImage(
            lsst_geom.Point2D(x + self.x_min, y + self.y_min)
        ).getArray()
        return resize_array(this_psf, self.shape)
[docs]
def truncate_square(arr: NDArray, rcut: int) -> None:
    """Zero out pixels outside a centred square support region, in place.

    With ``npix = arr.shape[0]`` and ``c = npix // 2``, only rows and
    columns ``c - rcut`` through ``c + rcut`` (``2 * rcut + 1`` pixels) are
    kept; everything else is set to zero.

    NOTE: when ``rcut == npix // 2 - 1`` the array is left untouched. This
    matches the historical behaviour of this helper and is kept on purpose.

    Parameters
    ----------
    arr : numpy.ndarray
        Square, two-dimensional array to modify in place.
    rcut : int
        Half-width of the square region that should be kept. It must
        satisfy ``0 <= rcut < npix // 2``.

    Raises
    ------
    ValueError
        If ``arr`` is not a square 2-D array, or if ``rcut`` is negative or
        not smaller than ``npix // 2``. (``assert`` was used before, which
        is stripped under ``python -O`` and contradicted this contract.)
    """
    if len(arr.shape) != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("Input array must be a 2D square array")
    npix = arr.shape[0]
    npix2 = npix // 2
    if rcut < 0:
        # A negative radius would silently zero the whole stamp.
        raise ValueError("truncation radius must be non-negative.")
    if rcut >= npix2:
        raise ValueError("truncation radius too large.")
    if rcut < npix2 - 1:
        lo = npix2 - rcut
        hi = npix2 + rcut + 1
        arr[:lo, :] = 0
        arr[hi:, :] = 0
        arr[:, :lo] = 0
        arr[:, hi:] = 0
[docs]
def get_psf_array(
    *,
    lsst_psf,
    lsst_bbox,
    npix: int,
    dg: int = 250,
    lsst_mask=None,
):
    """Compute an average PSF image over a regular grid.

    The PSF model is evaluated at grid points spaced ``dg`` pixels apart,
    starting 20 pixels inside the minimum corner of ``lsst_bbox``. Each
    sampled image is cropped/zero-padded to ``(npix, npix)`` with
    :func:`resize_array`, and the samples are averaged. The average is then
    truncated to a square of half-width ``npix // 2 - 2`` with
    :func:`truncate_square`.

    Parameters
    ----------
    lsst_psf
        LSST PSF model providing ``computeImage(Point2D)``.
    lsst_bbox : lsst.geom.Box2I
        Bounding box defining the sampling region. Its coordinates are
        passed directly to ``computeImage``.
    npix : int
        Side length of the output stamp.
    dg : int, optional
        Grid spacing in pixels (default is 250).
    lsst_mask : lsst.afw.image.Mask or None, optional
        Mask whose array covers ``lsst_bbox``. If it defines an
        ``INEXACT_PSF`` plane, grid points whose pixel has that bit set are
        skipped. A mask without that plane excludes nothing.

    Returns
    -------
    out : numpy.ndarray
        Averaged, truncated PSF of shape ``(npix, npix)`` and dtype float32.

    Raises
    ------
    ValueError
        If fewer than two grid points produce a PSF image.
    """
    x_min, y_min = lsst_bbox.getMin().getX(), lsst_bbox.getMin().getY()
    x_max, y_max = lsst_bbox.getMax().getX(), lsst_bbox.getMax().getY()
    # Ensure grid stays within the bbox and aligned with step size
    width = (x_max - x_min) // dg * dg
    height = (y_max - y_min) // dg * dg
    x_array = np.arange(x_min + 20, x_min + width - 20, dg, dtype=int)
    y_array = np.arange(y_min + 20, y_min + height - 20, dg, dtype=int)

    # Boolean map of INEXACT_PSF pixels, indexed relative to the bbox
    # minimum corner; ``None`` means no grid point is excluded. Previously
    # ``lsst_mask`` was accepted but never consulted.
    mask_array = None
    if lsst_mask is not None and (
        "INEXACT_PSF" in lsst_mask.getMaskPlaneDict()
    ):
        # Only ask for the bit when the plane exists in this mask.
        inexact_bit = lsst_mask.getPlaneBitMask("INEXACT_PSF")
        mask_array = (lsst_mask.array & inexact_bit) != 0

    out = np.zeros(shape=(npix, npix), dtype=np.float32)
    ncount = 0
    for yc in y_array:
        for xc in x_array:
            yim, xim = yc - y_min, xc - x_min
            if mask_array is not None and mask_array[yim, xim]:
                continue
            try:
                psf_img = lsst_psf.computeImage(
                    lsst_geom.Point2D(xc, yc)
                ).getArray()
                out += resize_array(psf_img, (npix, npix))
                ncount += 1
            except Exception:
                # Best effort: positions where the model cannot be
                # evaluated are skipped instead of aborting the average.
                continue
    if ncount < 2:
        raise ValueError("Could not find enough valid PSF samples to average.")
    out /= ncount
    psf_rcut = npix // 2 - 2
    truncate_square(out, psf_rcut)
    return out
[docs]
def get_blocks(
    *, lsst_psf, lsst_bbox, pixel_scale, npix, psf_array
):
    """Tile an exposure into anacal blocks and attach a local PSF to each.

    ``anacal.geometry.get_block_list`` is asked for blocks with
    ``block_nx = block_ny = 250`` and ``block_overlap = 80`` over an image
    the size of ``lsst_bbox``. For each block, the PSF model is evaluated at
    the block centre, first clamped onto the image and then shifted by the
    bbox minimum corner. The result, resized to ``(npix, npix)``, is stored
    as ``block.psf_array``.

    Parameters
    ----------
    lsst_psf
        LSST PSF model providing ``computeImage(Point2D)``.
    lsst_bbox
        Bounding box of the image being tiled.
    pixel_scale
        Forwarded as ``scale`` to ``get_block_list``.
    npix
        Side length of the per-block PSF stamps.
    psf_array
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    list
        Blocks whose PSF could be evaluated. Blocks where ``computeImage``
        (or the resize) raised are dropped.
    """
    corner = lsst_bbox.getMin()
    x_origin = corner.getX()
    y_origin = corner.getY()
    nx = lsst_bbox.getWidth()
    ny = lsst_bbox.getHeight()

    candidates = anacal.geometry.get_block_list(
        img_ny=ny,
        img_nx=nx,
        block_nx=250,
        block_ny=250,
        block_overlap=80,
        scale=pixel_scale,
    )

    kept = []
    for block in candidates:
        # Clamp the block centre onto the image before evaluating the PSF.
        xc = int(np.clip(block.xcen, 0, nx - 1))
        yc = int(np.clip(block.ycen, 0, ny - 1))
        try:
            stamp = lsst_psf.computeImage(
                lsst_geom.Point2D(x_origin + xc, y_origin + yc)
            ).getArray()
            block.psf_array = resize_array(stamp, (npix, npix))
        except Exception:
            # No usable PSF at this position: skip the block entirely.
            continue
        kept.append(block)
    return kept
[docs]
def get_blocks_cells(
    *, cell_coadd, pixel_scale, npix
):
    """Build one anacal block per cell of a cell-based coadd.

    Cells without a PSF image, or whose PSF image contains non-finite
    values, are skipped. For every other cell:

    * the block's outer bounds are the cell's outer bbox;
    * its inner bounds are the cell's inner bbox, pulled in where needed so
      they stay at least 10 pixels inside the outer bounds;
    * all coordinates are relative to the start of
      ``cell_coadd.outer_bbox``;
    * the cell PSF is resized to ``(npix, npix)`` and normalised to unit
      sum.

    Parameters
    ----------
    cell_coadd
        Cell coadd exposing ``outer_bbox`` and a ``cells`` mapping whose
        values provide ``outer.bbox``, ``inner.bbox`` and ``psf_image``.
    pixel_scale
        Pixel scale passed to ``anacal.geometry.block``.
    npix
        Side length of the per-block PSF stamp.

    Returns
    -------
    list
        The constructed blocks. Each block's index is the cell's position
        in ``cell_coadd.cells.values()``, so skipped cells leave gaps.
    """
    x_origin = cell_coadd.outer_bbox.beginX
    y_origin = cell_coadd.outer_bbox.beginY
    margin = 10
    out = []
    for index, cell in enumerate(cell_coadd.cells.values()):
        psf_image = getattr(cell, "psf_image", None)
        stamp = None if psf_image is None else getattr(psf_image, "array", None)
        if stamp is None or not np.isfinite(stamp).all():
            continue

        xmin = cell.outer.bbox.beginX - x_origin
        ymin = cell.outer.bbox.beginY - y_origin
        xmax = cell.outer.bbox.endX - x_origin
        ymax = cell.outer.bbox.endY - y_origin
        # Inner region, kept at least ``margin`` pixels inside the outer one.
        xmin_in = max(cell.inner.bbox.beginX - x_origin, xmin + margin)
        ymin_in = max(cell.inner.bbox.beginY - y_origin, ymin + margin)
        xmax_in = min(cell.inner.bbox.endX - x_origin, xmax - margin)
        ymax_in = min(cell.inner.bbox.endY - y_origin, ymax - margin)

        block = anacal.geometry.block(
            int((xmin + xmax) // 2),
            int((ymin + ymax) // 2),
            xmin,
            ymin,
            xmax,
            ymax,
            xmin_in,
            ymin_in,
            xmax_in,
            ymax_in,
            pixel_scale,
            index,
        )
        # Store the resized stamp, then renormalise the stored copy.
        block.psf_array = resize_array(stamp, (npix, npix))
        block.psf_array = block.psf_array / np.sum(block.psf_array)
        out.append(block)
    return out
[docs]
def stack_psfs_cells(
    *, cell_coadd, npix
):
    """Average the PSF images of all cells that carry a finite PSF.

    A cell is skipped when it has no ``psf_image``, when that image has no
    ``array``, or when the array contains non-finite values. Each usable
    stamp is cropped/zero-padded to ``(npix, npix)`` with
    :func:`resize_array` before averaging. The stamps are not renormalised.

    Parameters
    ----------
    cell_coadd
        Cell coadd whose ``cells`` mapping values provide ``psf_image``.
    npix
        Side length of the output stamp.

    Returns
    -------
    numpy.ndarray
        Mean PSF stamp of shape ``(npix, npix)`` (float64).

    Raises
    ------
    ValueError
        If no cell provides a usable PSF image. Without this check the
        function would divide 0 by 0 and return an all-NaN stamp.
    """
    psf_array = np.zeros((npix, npix))
    npsf = 0
    for cell in cell_coadd.cells.values():
        psf_image = getattr(cell, "psf_image", None)
        p0 = None if psf_image is None else getattr(psf_image, "array", None)
        if p0 is None or not np.isfinite(p0).all():
            continue
        psf_array = psf_array + resize_array(p0, (npix, npix))
        npsf += 1
    if npsf == 0:
        raise ValueError("No cell provides a finite PSF image to stack.")
    return psf_array / npsf
[docs]
def combine_sim_exposures(
    exposures: Sequence,
    noises: Sequence[NDArray],
):
    """Combine simulated exposures using inverse-variance weights.

    Each exposure ``i`` gets a single scalar weight ``w_i = 1 / v_i``, where
    ``v_i`` is the mean of the *finite* values of its variance plane. The
    image planes and the paired pure-noise arrays are averaged with these
    weights. The variance plane of the result is set to the constant
    ``1 / sum_i w_i``.

    Parameters
    ----------
    exposures
        Exposures to combine. Each must provide ``getMaskedImage()`` (with
        ``image`` / ``variance`` planes) and ``clone()``. All image planes
        must share one shape.
    noises
        Pure-noise arrays paired one-to-one with ``exposures``. Each must
        have the same shape as the image planes.

    Returns
    -------
    combined_exposure
        Clone of ``exposures[0]`` with its image and variance planes
        overwritten. All other content (e.g. the mask plane) is that of
        ``exposures[0]``.
    combined_noise : numpy.ndarray
        Weighted average of ``noises``.

    Raises
    ------
    ValueError
        If the inputs are empty or of different lengths, if shapes disagree,
        or if a variance plane has no finite, positive mean.
    """
    if len(exposures) != len(noises):
        raise ValueError("exposure and noises should have the same length")
    if len(exposures) <= 0:
        raise ValueError("no elements in the input list")
    reference_shape = exposures[0].getMaskedImage().image.array.shape
    combined_image = np.zeros(reference_shape, dtype=np.float32)
    combined_noise = np.zeros(reference_shape, dtype=noises[0].dtype)
    total_weight = 0.0
    for exposure, noise_array in zip(exposures, noises):
        masked_image = exposure.getMaskedImage()
        image = masked_image.image.array
        variance = masked_image.variance.array
        if image.shape != reference_shape:
            raise ValueError("All exposures must share the same image shape")
        if np.shape(noise_array) != reference_shape:
            # Without this check a smaller but broadcast-compatible array
            # (e.g. a single row) would be silently broadcast.
            raise ValueError("Noise arrays must share the exposure image shape")
        finite_variance = variance[np.isfinite(variance)]
        if finite_variance.size == 0:
            raise ValueError(
                "Variance plane must contain at least one finite value"
            )
        # Average exactly the finite entries validated above. ``nanmean``
        # over the whole plane would turn a single +/-inf pixel into an
        # unusable weight.
        variance_value = float(np.mean(finite_variance))
        if not np.isfinite(variance_value):
            raise ValueError("Variance mean must be finite")
        if variance_value <= 0:
            raise ValueError("Variance values must be positive")
        weight = 1.0 / variance_value
        combined_image += weight * image
        combined_noise += weight * noise_array
        total_weight += weight
    if total_weight <= 0:
        raise ValueError("Total weight must be positive")
    combined_image = combined_image / total_weight
    combined_noise = combined_noise / total_weight
    combined_variance = 1.0 / total_weight
    combined_exposure = exposures[0].clone()
    image_plane = combined_exposure.getMaskedImage().image.array
    variance_plane = combined_exposure.getMaskedImage().variance.array
    image_plane[:, :] = combined_image.astype(image_plane.dtype, copy=False)
    variance_plane[:, :] = combined_variance
    return combined_exposure, combined_noise
[docs]
def rotate_noise_corr(noise_corr):
    """Normalise a noise-correlation stamp to unit peak and rotate it.

    The stamp is divided by its maximum and then rotated with
    ``np.rot90(..., k=-1)``. ``generate_pure_noise`` applies this to
    ``noise_corr`` before passing it to ``anacal.noise.simulate_noise``.

    Parameters
    ----------
    noise_corr
        Two-dimensional array-like with odd side lengths whose maximum is
        at the central pixel.

    Returns
    -------
    numpy.ndarray
        The normalised, rotated stamp (central value exactly 1).

    Raises
    ------
    ValueError
        If the stamp has an even side length or does not peak at its
        centre. These checks used to be ``assert`` statements, which are
        stripped under ``python -O``.
    """
    noise_corr = np.asarray(noise_corr)
    noise_max = np.amax(noise_corr)
    noise_corr = noise_corr / noise_max
    ny2, nx2 = noise_corr.shape
    if ny2 % 2 != 1 or nx2 % 2 != 1:
        raise ValueError("noise_corr must have odd dimensions")
    # x / x is exactly 1.0 in IEEE arithmetic, so this holds precisely when
    # the centre is the maximum (and is False for NaN input).
    if noise_corr[ny2 // 2, nx2 // 2] != 1:
        raise ValueError("noise_corr must peak at its central pixel")
    return np.rot90(m=noise_corr, k=-1)
[docs]
def generate_pure_noise(
    *,
    ny: int,
    nx: int,
    pixel_scale: float,
    seed: int,
    band: str | None,
    noise_variance: float,
    noise_corr=None,
    noiseId: int = 0,
    rotId: int = 0,
):
    """Draw a pure-noise image of shape ``(ny, nx)``.

    Parameters
    ----------
    ny, nx
        Output image shape (rows, columns).
    pixel_scale
        Forwarded as ``scale`` to ``anacal.noise.simulate_noise``. It is
        only used when ``noise_corr`` is given.
    seed, band, noiseId, rotId
        Forwarded to ``get_noise_seed`` (with ``is_sim=False``) to derive
        the random seed.
    noise_variance
        Target variance; the noise is scaled by its square root.
    noise_corr
        Optional noise-correlation stamp. When ``None`` the noise is white
        Gaussian drawn from ``numpy.random.RandomState``. Otherwise the
        stamp goes through ``rotate_noise_corr`` and is used as the
        ``correlation`` argument of ``anacal.noise.simulate_noise``.

    Returns
    -------
    numpy.ndarray
        Noise image of shape ``(ny, nx)``.
    """
    from .random import get_noise_seed

    sigma = np.sqrt(noise_variance)
    rng_seed = get_noise_seed(
        galaxy_seed=seed,
        noiseId=noiseId,
        rotId=rotId,
        band=band,
        is_sim=False,
    )
    if noise_corr is not None:
        # Correlated noise: anacal simulates the field, we scale it.
        field = anacal.noise.simulate_noise(
            seed=rng_seed,
            correlation=rotate_noise_corr(noise_corr),
            nx=nx,
            ny=ny,
            scale=pixel_scale,
        )
        return field * sigma
    # White Gaussian noise from a seeded legacy RandomState stream.
    white = np.random.RandomState(rng_seed).normal(scale=sigma, size=(ny, nx))
    return white.astype(np.float64)
[docs]
def estimate_noise_variance(exposure, mask_array=None):
    """Estimate the image noise variance as a median of the variance plane.

    A pixel is used only if all of the following hold:

    * its variance is below ``1e5`` (which also rejects NaN);
    * its ``exposure.mask`` value is zero;
    * its ``mask_array`` value is zero, when ``mask_array`` is given.

    Parameters
    ----------
    exposure
        Object exposing ``variance.array`` and ``mask.array``.
    mask_array
        Optional extra pixel mask; nonzero entries are excluded.

    Returns
    -------
    float
        ``np.nanmedian`` of the variance over the selected pixels.

    Raises
    ------
    ValueError
        If fewer than 10 pixels are selected, or if the median is NaN or
        below ``1e-10``.
    """
    variance = exposure.variance.array
    usable = (variance < 1e5) & (exposure.mask.array == 0)
    if mask_array is not None:
        usable = usable & (mask_array == 0)
    if np.sum(usable) < 10:
        raise ValueError(
            "Do not have enough valid pixels"
        )
    estimate = np.nanmedian(variance[usable])
    if np.isnan(estimate) or estimate < 1e-10:
        raise ValueError(
            "the estimated image noise variance should be positive."
        )
    return estimate
[docs]
def prepare_data(
    *,
    band: str | None,
    exposure,
    seed: int,
    noiseId: int = 0,
    rotId: int = 0,
    npix: int = 32,
    noise_corr: NDArray | None = None,
    do_noise_bias_correction: bool = True,
    badMaskPlanes: List[str] = badMaskDefault,
    skyMap=None,
    tract: int = 0,
    patch: int = 0,
    star_cat: NDArray | None = None,
    psf_array: NDArray | None = None,
    mask_array: NDArray | None = None,
    noise_array: NDArray | None = None,
    detection: astropy.table.Table | None = None,
    blocks: List | None = None,
    **kwargs,
):
    """Collect image, mask, noise and PSF arrays for shear measurement.

    Gathers what the ``anacal`` measurement step needs from one LSST
    exposure into a single dictionary:

    * the galaxy image, passed through ``anacal.mask.mask_galaxy_image``
      with ``mask_array`` and ``star_cat``;
    * a pixel mask;
    * an average PSF stamp;
    * the estimated noise variance;
    * an optional pure-noise image for noise-bias correction;
    * WCS / sky-map metadata.

    Detections, if given, are reduced to the ``anacal.table.column_names()``
    columns. If ``blocks`` is also given, each detection is tagged with the
    index of the block whose inner region contains it.

    Parameters
    ----------
    band : str or None
        Band label. It is only forwarded to ``generate_pure_noise`` (which
        uses it to derive the noise seed) when a noise image is generated.
        It is not stored in the output.
    exposure : lsst.afw.image.ExposureF
        Source of the WCS, photometric calibration, bounding box, PSF model
        and image / mask / variance planes.
    seed : int
        Base seed forwarded to ``generate_pure_noise``.
    noiseId : int, optional
        Noise-realisation identifier forwarded to ``generate_pure_noise``.
        Defaults to ``0``.
    rotId : int, optional
        Rotation identifier forwarded to ``generate_pure_noise``. Defaults
        to ``0``.
    npix : int, optional
        PSF stamp size used when ``psf_array`` is computed here. Defaults
        to ``32``.
    noise_corr : numpy.ndarray, optional
        Noise-correlation stamp forwarded to ``generate_pure_noise``.
    do_noise_bias_correction : bool, optional
        If ``True`` (default), a pure-noise image is generated (unless
        ``noise_array`` is supplied) and masked like the galaxy image. If
        ``False``, ``noise_array`` in the output is ``None``, even when one
        was supplied.
    badMaskPlanes : list of str, optional
        Mask planes flagged as bad when ``mask_array`` is built here.
    skyMap : optional
        Sky map. When given, ``skyMap[tract]`` and ``skyMap[tract][patch]``
        are returned as ``tractInfo`` / ``patchInfo``. The sky map itself is
        returned unchanged.
    tract, patch : int, optional
        Keys used to look up ``tractInfo`` / ``patchInfo`` in ``skyMap``.
    star_cat : numpy.ndarray, optional
        Passed to ``anacal.mask.mask_galaxy_image`` for both the galaxy and
        the noise image. Presumably it lists stars to be masked; confirm
        against anacal.
    psf_array : numpy.ndarray, optional
        PSF stamp. If ``None``, it is computed with ``get_psf_array``
        (``dg=250``, ``lsst_mask=exposure.mask``) and converted to float64.
    mask_array : numpy.ndarray, optional
        Pixel mask; nonzero entries are excluded from the noise-variance
        estimate. If ``None``, an int16 mask is built from the exposure. It
        is nonzero where any ``badMaskPlanes`` bit is set or where the pixel
        value is below ``-6 * sqrt(variance)`` (negative variances clipped
        to 0).
    noise_array : numpy.ndarray, optional
        Pre-computed pure-noise image. It is only used when
        ``do_noise_bias_correction`` is ``True``. If ``None`` in that case,
        a realisation is generated with ``generate_pure_noise`` using the
        estimated noise variance.
    detection : astropy.table.Table or numpy.ndarray, optional
        Detection catalogue. Tables are converted to a structured array
        copy, and ndarrays are copied. The result is restricted to
        ``anacal.table.column_names()``. Required when ``blocks`` is given.
    blocks : list, optional
        Blocks (e.g. from ``get_blocks`` / ``get_blocks_cells``) used to
        fill ``detection["block_id"]``. Returned unchanged.
    **kwargs
        Accepted for call-site compatibility and ignored.

    Returns
    -------
    dict
        Keys ``pixel_scale`` (arcsec), ``mag_zero``, ``noise_variance``,
        ``gal_array``, ``psf_array``, ``mask_array``, ``noise_array``,
        ``begin_x`` / ``begin_y`` (bbox minimum corner), ``wcs``,
        ``skyMap``, ``tractInfo``, ``patchInfo``, ``detection`` and
        ``blocks``.

    Raises
    ------
    ValueError
        Propagated from ``estimate_noise_variance`` (too few valid pixels or
        non-positive variance) and ``get_psf_array`` (too few PSF samples).
    """
    pixel_scale = float(exposure.getWcs().getPixelScale().asArcseconds())
    # 2.5 * log10 of the instrumental flux of a zero-magnitude source.
    mag_zero = (
        np.log10(exposure.getPhotoCalib().getInstFluxAtZeroMagnitude()) / 0.4
    )
    wcs = exposure.getWcs()
    lsst_bbox = exposure.getBBox()
    if psf_array is None:
        psf_array = np.asarray(
            get_psf_array(
                lsst_psf=exposure.getPsf(),
                lsst_bbox=lsst_bbox,
                npix=npix,
                dg=250,
                lsst_mask=exposure.mask,
            ),
            dtype=np.float64,
        )
    # NOTE(review): np.asarray copies only when the dtype changes. For a
    # float32 image plane this is a copy. If the plane were already float64,
    # the masking below would also modify ``exposure`` — confirm inputs.
    gal_array = np.asarray(
        exposure.image.array,
        dtype=np.float64,
    )
    if mask_array is None:
        bitv = exposure.mask.getPlaneBitMask(badMaskPlanes)
        # Bad pixels are those with any ``badMaskPlanes`` bit set, or with
        # a value more than 6 sigma below zero (negative variances are
        # clipped to 0 before the square root).
        mask_array = (
            ((exposure.mask.array & bitv) != 0)
            | (
                exposure.image.array
                < (
                    -6.0
                    * np.sqrt(
                        np.where(
                            exposure.variance.array < 0,
                            0, exposure.variance.array,
                        )
                    )
                )
            )
        ).astype(np.int16)
    # Set the value inside star mask to zero
    # (the return value is not used, so this presumably edits in place).
    anacal.mask.mask_galaxy_image(
        gal_array,
        mask_array,
        False,  # extend mask
        star_cat,
    )
    # Median variance over pixels unflagged in both exposure.mask and
    # mask_array.
    noise_variance = estimate_noise_variance(exposure, mask_array)
    if do_noise_bias_correction:
        if noise_array is None:
            ny, nx = gal_array.shape
            noise_array = generate_pure_noise(
                ny=ny,
                nx=nx,
                pixel_scale=pixel_scale,
                seed=seed,
                band=band,
                noise_variance=noise_variance,
                noise_corr=noise_corr,
                noiseId=noiseId,
                rotId=rotId,
            )
        # apply pixel mask to pure noise image
        anacal.mask.mask_galaxy_image(
            noise_array,
            mask_array,
            False,  # extend mask
            star_cat,
        )
    else:
        # Noise-bias correction disabled: any supplied noise is dropped.
        noise_array = None
    if skyMap is not None:
        tractInfo = skyMap[tract]
        patchInfo = tractInfo[patch]
    else:
        tractInfo = None
        patchInfo = None
    # Minimum corner of the exposure bounding box.
    beginx = lsst_bbox.beginX
    beginy = lsst_bbox.beginY
    if detection is not None:
        # Tables become a structured-array copy; ndarrays are copied.
        if isinstance(detection, astropy.table.Table):
            detection = detection.copy().as_array()
        elif isinstance(detection, np.ndarray):
            detection = detection.copy()
        assert detection is not None
        # Keep only the columns anacal knows about, packed contiguously.
        detection = rfn.repack_fields(
            detection[list(anacal.table.column_names())]
        )
    if blocks is not None:
        assert detection is not None
        # Tag each detection with the index of the block whose inner region
        # contains it. A detection inside several inner regions keeps the
        # index of the last matching block.
        # NOTE(review): assumes x1_det / x2_det share the unit of
        # ``pixel_scale`` (arcsec) and the origin of beginx / beginy —
        # confirm against the detection producer.
        for bb in blocks:
            mm = (
                (detection["x2_det"] / pixel_scale - beginy >= bb.ymin_in)
                & (detection["x2_det"] / pixel_scale - beginy < bb.ymax_in)
                & (detection["x1_det"] / pixel_scale - beginx >= bb.xmin_in)
                & (detection["x1_det"] / pixel_scale - beginx < bb.xmax_in)
            )
            detection["block_id"][mm] = bb.index
    return {
        "pixel_scale": pixel_scale,
        "mag_zero": mag_zero,
        "noise_variance": noise_variance,
        "gal_array": gal_array,
        "psf_array": psf_array,
        "mask_array": mask_array,
        "noise_array": noise_array,
        "begin_x": beginx,
        "begin_y": beginy,
        "wcs": wcs,
        "skyMap": skyMap,
        "tractInfo": tractInfo,
        "patchInfo": patchInfo,
        "detection": detection,
        "blocks": blocks,
    }