Source code for xlens.simulator.catalog

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Simple example with ring test (rotating intrinsic galaxies)
# Copyright 2023-2025 Xiangchong Li.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
"""Pipeline task that prepares truth catalogs for image simulations."""

import os
from typing import Any

import lsst.pipe.base.connectionTypes as cT
import numpy as np
from lsst.pex.config import Field, FieldValidationError, ListField
from lsst.pipe.base import (
    PipelineTask,
    PipelineTaskConfig,
    PipelineTaskConnections,
    Struct,
)
from lsst.skymap import BaseSkyMap

from ..utils.random import (
    gal_seed_base,
    num_rot,
)
from .galaxies import (
    CatSim2017Catalog,
    Flagship2025Catalog,
    OpenUniverse2024RubinRomanCatalog,
)
from .perturbation import ShearHalo, ShearLogNormalFlat, ShearRedshift


class CatalogConnections(
    PipelineTaskConnections,
    dimensions=("skymap", "tract"),
    defaultTemplates={
        "coaddName": "deep",
        "simCoaddName": "sim",
        "mode": 0,
        "rotId": 0,
    },
):
    """Butler inputs and outputs for the truth-catalog tasks.

    One quantum is defined per ``(skymap, tract)``.  The task reads the
    sky map and writes a single truth catalog whose dataset name is built
    from the ``simCoaddName``, ``mode`` and ``rotId`` templates.  The
    ``coaddName`` template is declared but not referenced by any
    connection in this class.
    """

    # Input: the sky map from which the tract geometry is looked up.
    skymap = cT.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="SkyMap to use in processing",
    )

    # Output: one truth catalog per tract.
    truthCatalog = cT.Output(
        name="{simCoaddName}_{mode}_rot{rotId}_coadd_truthCatalog",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract"),
        doc="Output truth catalog",
    )

    def __init__(self, *, config=None):
        """Forward ``config`` unchanged to the base connections class."""
        super().__init__(config=config)
class CatalogConfig(
    PipelineTaskConfig,
    pipelineConnections=CatalogConnections,
):
    """Configuration options used by :class:`CatalogTask`.

    The fields select the input galaxy catalog, the layout of galaxies
    on the tract, the random-seed/rotation indices, and optional cuts on
    catalog observables.  :meth:`validate` enforces the index ranges and
    the consistency of the selection lists.
    """

    catsim_dir = Field[str](
        doc="Directory containing input galaxy catalogs.",
        # Evaluated once, at import time, from the environment.
        default=os.environ.get("CATSIM_DIR", "."),
    )
    galaxy_type = Field[str](
        doc="galaxy type",
        default="catsim2017",
    )
    layout = Field[str](
        doc="layout type",
        default="random",
    )
    galId = Field[int](
        doc="random seed index for galaxy, 0 <= galId < 10",
        default=0,
    )
    rotId = Field[int](
        doc="index of the rotation angle, 0 <= rotId < num_rot",
        default=0,
    )
    indice_group_id = Field[int](
        doc="indice group index, if <0, randomly select indices",
        default=-1,
    )
    sep_arcsec = Field[float](
        doc="Spacing (arcsec) for 'grid'/'hex' layout",
        default=12.0,
    )
    extend_ratio = Field[float](
        doc="ratio of padded coverage length of galaxy catalog",
        default=1.08,
    )
    force_pixel_center = Field[bool](
        doc="Force catalog shifts to align with pixel centers.",
        default=False,
    )
    apply_lensing_position_shifts = Field[bool](
        doc=(
            "If False, retain original galaxy positions after lensing so that "
            "image coordinates remain equal to the pre-lensing values."
        ),
        default=True,
    )
    select_observable = ListField[str](
        doc=(
            "Optional catalog observable names used to filter galaxies. "
            "When provided, each name is paired with the corresponding "
            "lower and/or upper limits."
        ),
        default=[],
    )
    select_lower_limit = ListField[float](
        doc=(
            "Lower limits for the observables listed in ``select_observable``. "
            "Leave empty to disable minimum filtering for a quantity."
        ),
        default=[],
    )
    select_upper_limit = ListField[float](
        doc=(
            "Upper limits for the observables listed in ``select_observable``. "
            "Leave empty to disable maximum filtering for a quantity."
        ),
        default=[],
    )

    def validate(self):
        """Check index ranges, galaxy type and selection-list consistency.

        Raises
        ------
        FieldValidationError
            If ``galId`` is outside ``[0, gal_seed_base)``, ``rotId`` is
            outside ``[0, num_rot)``, ``galaxy_type`` is not one of the
            supported names, selection limits are given without
            ``select_observable``, or the non-empty selection lists have
            different lengths.
        """
        super().validate()

        if not 0 <= self.galId < gal_seed_base:
            raise FieldValidationError(
                self.__class__.galId,
                self,
                f"We require 0 <= galId < {gal_seed_base:d}",
            )

        # rotId indexes the task's list of rotation angles; a negative
        # value would silently wrap around to the end of that list.
        if not 0 <= self.rotId < num_rot:
            raise FieldValidationError(
                self.__class__.rotId,
                self,
                f"We require 0 <= rotId < {num_rot}",
            )

        supported_types = ["catsim2017", "RomanRubin2024", "flagship2025"]
        if self.galaxy_type not in supported_types:
            raise FieldValidationError(
                self.__class__.galaxy_type,
                self,
                f"We require galaxy_type in {supported_types}",
            )

        selection = (
            self.select_observable,
            self.select_lower_limit,
            self.select_upper_limit,
        )
        # Only lists that were actually filled in take part in the checks;
        # ``None`` and empty lists both count as "not provided".
        provided = [
            values
            for values in selection
            if values is not None and len(values) > 0
        ]
        if not provided:
            return

        if self.select_observable is None or len(self.select_observable) == 0:
            raise FieldValidationError(
                self.__class__.select_observable,
                self,
                "select_observable must be provided when selection limits "
                "are specified.",
            )
        if len({len(values) for values in provided}) > 1:
            raise FieldValidationError(
                self.__class__.select_observable,
                self,
                "select_observable, select_lower_limit, and "
                "select_upper_limit must have identical lengths.",
            )

    def setDefaults(self):
        """Apply the base-class defaults; no overrides are added here."""
        super().setDefaults()
class CatalogTask(PipelineTask):
    """Task that creates lensed galaxy catalogs for downstream simulations.

    This base class builds the galaxy catalog, rotates it, and applies a
    lensing perturbation.  The perturbation itself is supplied by
    subclasses through :meth:`get_perturbation_object`.
    """

    _DefaultName = "CatalogTask"
    ConfigClass = CatalogConfig

    def __init__(self, **kwargs: Any):
        """Initialise the task and precompute the rotation angles."""
        super().__init__(**kwargs)
        assert isinstance(self.config, CatalogConfig)
        # ``num_rot`` angles evenly spaced over [0, pi).
        self.rotate_list = [np.pi / num_rot * i for i in range(num_rot)]

    def get_perturbation_object(
        self, tract_info, seed: int, **kwargs: Any
    ) -> object:
        """Return a perturbation object for lensing the catalog.

        Must be implemented by subclasses (e.g. shear, halo, log-normal).

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError(
            "'get_perturbation_object' must be implemented by subclasses."
        )

    @staticmethod
    def _as_optional_list(values):
        """Return ``values`` as a plain list, or None if unset or empty."""
        if values is None:
            return None
        return list(values) or None

    def prepare_galaxy_catalog(
        self,
        *,
        seed,
        tract_info,
    ):
        """Instantiate a galaxy catalog class based on the configuration.

        Parameters
        ----------
        seed
            Seed for the ``numpy.random.RandomState`` handed to the catalog.
        tract_info
            Tract description forwarded unchanged to the catalog class.

        Returns
        -------
        The catalog instance selected by ``config.galaxy_type``.

        Raises
        ------
        ValueError
            If ``config.galaxy_type`` is not a supported name.
        """
        assert isinstance(self.config, CatalogConfig)
        config = self.config

        if config.galaxy_type == "catsim2017":
            GalClass = CatSim2017Catalog
        elif config.galaxy_type == "RomanRubin2024":
            GalClass = OpenUniverse2024RubinRomanCatalog
        elif config.galaxy_type == "flagship2025":
            GalClass = Flagship2025Catalog
        else:
            raise ValueError("invalid galaxy_type")

        rng = np.random.RandomState(seed)
        # Unset or empty selection lists are passed on as None.
        return GalClass(
            rng=rng,
            tract_info=tract_info,
            layout_name=config.layout,
            sep_arcsec=config.sep_arcsec,
            indice_group_id=config.indice_group_id,
            extend_ratio=config.extend_ratio,
            force_pixel_center=config.force_pixel_center,
            catsim_dir=config.catsim_dir,
            select_observable=self._as_optional_list(
                config.select_observable
            ),
            select_lower_limit=self._as_optional_list(
                config.select_lower_limit
            ),
            select_upper_limit=self._as_optional_list(
                config.select_upper_limit
            ),
        )

    def run(
        self,
        *,
        tract_info,
        seed: int,
        **kwargs,
    ):
        """Generate a truth catalog with the configured lensing perturbations.

        Parameters
        ----------
        tract_info
            Tract description used to build the catalog and the
            perturbation object.
        seed : int
            Base seed.  The galaxy catalog is seeded with
            ``seed * gal_seed_base + config.galId``; the perturbation
            object receives ``seed`` itself.
        **kwargs
            Accepted and ignored, so that extra butler inputs can be
            passed through from ``runQuantum``.

        Returns
        -------
        Struct
            With a single attribute ``truthCatalog`` holding the ``data``
            attribute of the lensed catalog.
        """
        assert isinstance(self.config, CatalogConfig)
        galaxy_seed = seed * gal_seed_base + self.config.galId
        galaxy_catalog = self.prepare_galaxy_catalog(
            seed=galaxy_seed,
            tract_info=tract_info,
        )

        # Rotate first, then lens.
        theta0 = self.rotate_list[self.config.rotId]
        galaxy_catalog.rotate(theta0)

        shear_obj = self.get_perturbation_object(tract_info, seed)
        galaxy_catalog.lens(
            shear_obj=shear_obj,
            apply_position_shifts=self.config.apply_lensing_position_shifts,
        )
        return Struct(truthCatalog=galaxy_catalog.data)

    def runQuantum(self, butlerQC, inputRefs, outputRefs) -> None:
        """Fetch inputs for one quantum, call :meth:`run`, store outputs.

        The tract id from the quantum's data ID is used both as the random
        seed and as the key to look up the tract in the sky map.
        """
        data_id = butlerQC.quantum.dataId
        assert data_id is not None
        tract = data_id["tract"]

        inputs = butlerQC.get(inputRefs)
        inputs["seed"] = tract
        # ``inputs`` already holds the sky map under its connection name,
        # so it is reused here instead of being read from the butler again.
        inputs["tract_info"] = inputs["skymap"][tract]

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)
class CatalogShearTaskConfig(
    CatalogConfig,
    pipelineConnections=CatalogConnections,
):
    """Configuration for :class:`CatalogShearTask` (constant-shear test).

    Adds the redshift binning, the ternary ``mode`` code, the shear
    component under test and its amplitude, and an optional kappa value
    to the options inherited from :class:`CatalogConfig`.
    """

    z_bounds = ListField[float](
        doc="boundary list of the redshift",
        default=[-0.01, 20.0],
    )
    # NOTE(review): the worked example below lists the shear values in the
    # left-to-right order of the ternary string, which does not match the
    # "lowest-z is least significant digit" statement above it.  Confirm
    # against ShearRedshift before relying on either reading.
    mode = Field[int](
        doc=(
            "Ternary-encoded shear assignment per z-bin.\n"
            "Each digit in base-3 is one bin \n"
            "(lowest-z is least significant digit):\n"
            " 0 -> -test_value, 1 -> +test_value, 2 -> 0.0\n"
            "Example: z_bounds=[0.,0.5,1.0,1.5,2.0] (4 bins). \n"
            "mode=7 -> '0021' (ternary)\n"
            "=> (-g, -g, 0, +g) for bins: \n"
            " [0,0.5), [0.5,1.0), [1.0,1.5), [1.5,2.0)."
        ),
        default=0,
    )
    test_target = Field[str](
        doc="the shear component to test",
        default="g1",
    )
    test_value = Field[float](
        doc="absolute value of the shear",
        default=0.02,
    )
    kappa_value = Field[float](
        doc="kappa value to use, 0. means no kappa",
        default=0.,
    )

    def validate(self):
        """Check the redshift binning, mode code and shear settings.

        Raises
        ------
        FieldValidationError
            If ``z_bounds`` defines fewer than one bin, ``mode`` is outside
            ``[0, 3 ** n_zbins)``, ``test_target`` is not ``'g1'`` or
            ``'g2'``, or ``test_value`` is outside ``[0.0, 0.5]``.
        """
        super().validate()

        n_zbins = len(self.z_bounds) - 1
        if n_zbins < 1:
            raise FieldValidationError(
                self.__class__.z_bounds,
                self,
                "number of redshift bins: %d is less than 1" % n_zbins,
            )

        # One base-3 digit per redshift bin, so valid codes are
        # 0 .. 3**n_zbins - 1; a negative code has no ternary encoding.
        mode_max = 3 ** n_zbins
        if not 0 <= self.mode < mode_max:
            raise FieldValidationError(
                self.__class__.mode,
                self,
                "mode needs to satisfy 0 <= mode < %d" % mode_max,
            )

        if self.test_target not in ["g1", "g2"]:
            raise FieldValidationError(
                self.__class__.test_target,
                self,
                "test target can only be 'g1' or 'g2'",
            )

        if self.test_value < 0.0 or self.test_value > 0.50:
            raise FieldValidationError(
                self.__class__.test_value,
                self,
                # The message states the bounds that are actually enforced.
                "test_value should be in [0.00, 0.50]",
            )

    def setDefaults(self):
        """Apply the base-class defaults; no overrides are added here."""
        super().setDefaults()
class CatalogShearTask(CatalogTask):
    """Truth-catalog task that lenses galaxies with :class:`ShearRedshift`.

    All shear settings are taken from :class:`CatalogShearTaskConfig`.
    """

    _DefaultName = "CatalogShearTask"
    ConfigClass = CatalogShearTaskConfig

    def __init__(self, **kwargs: Any):
        """Forward all keyword arguments to :class:`CatalogTask`."""
        super().__init__(**kwargs)

    def get_perturbation_object(self, tract_info, seed: int, **kwargs: Any):
        """Build the ``ShearRedshift`` object described by the config.

        ``tract_info``, ``seed`` and ``kwargs`` are accepted for interface
        compatibility with the base class and are not used here.
        """
        cfg = self.config
        assert isinstance(cfg, CatalogShearTaskConfig)
        shear_settings = {
            "mode": cfg.mode,
            "g_dist": cfg.test_target,
            "shear_value": cfg.test_value,
            "z_bounds": cfg.z_bounds,
            "kappa_value": cfg.kappa_value,
        }
        return ShearRedshift(**shear_settings)
class CatalogHaloTaskConfig(
    CatalogConfig,
    pipelineConnections=CatalogConnections,
):
    """Options for :class:`CatalogHaloTask`.

    Extends :class:`CatalogConfig` with the halo mass, concentration and
    redshift, an optional fixed source redshift, and a switch controlling
    the kappa field.
    """

    mass = Field[float](
        default=5e14,
        doc="halo mass",
    )
    conc = Field[float](
        default=1.0,
        doc="halo concertration",
    )
    z_lens = Field[float](
        default=1.0,
        doc="halo redshift",
    )
    z_source = Field[float](
        default=None,
        optional=True,
        doc="Fixed redshift for all galaxies. If None, use catalog values.",
    )
    no_kappa = Field[bool](
        default=False,
        doc="whether to exclude kappa field",
    )

    def validate(self):
        """Reject implausible halo parameters.

        Raises
        ------
        FieldValidationError
            If ``mass`` is below 1e8, ``z_lens`` lies outside ``[0, 5]``,
            or a fixed ``z_source`` is smaller than ``z_lens``.
        """
        super().validate()
        config_cls = self.__class__

        if self.mass < 1e8:
            raise FieldValidationError(
                config_cls.mass,
                self,
                "halo mass too small",
            )

        if self.z_lens < 0 or self.z_lens > 5.0:
            raise FieldValidationError(
                config_cls.z_lens,
                self,
                "halo redshift is wrong",
            )

        # The lens/source ordering is only checked for a fixed source
        # redshift; with ``z_source=None`` there is nothing to compare.
        if self.z_source is not None and self.z_lens > self.z_source:
            raise FieldValidationError(
                config_cls.z_source,
                self,
                "halo redshift is larger than source redshift",
            )

    def setDefaults(self):
        """Use the inherited defaults without modification."""
        super().setDefaults()
class CatalogHaloTask(CatalogTask):
    """Truth-catalog task that lenses galaxies with :class:`ShearHalo`.

    Halo parameters come from :class:`CatalogHaloTaskConfig`.
    """

    _DefaultName = "CatalogHaloTask"
    ConfigClass = CatalogHaloTaskConfig

    def __init__(self, **kwargs: Any):
        """Initialise the base task and check the config type."""
        super().__init__(**kwargs)
        assert isinstance(self.config, CatalogHaloTaskConfig)

    def prepare_galaxy_catalog(
        self,
        *,
        seed,
        tract_info,
    ):
        """Build the base catalog and optionally fix its source redshift.

        When ``config.z_source`` is set, it is applied to the catalog via
        ``set_z_source``; otherwise the catalog is returned as built.
        """
        cfg = self.config
        assert isinstance(cfg, CatalogHaloTaskConfig)
        catalog = super().prepare_galaxy_catalog(
            seed=seed,
            tract_info=tract_info,
        )
        fixed_redshift = cfg.z_source
        if fixed_redshift is None:
            return catalog
        catalog.set_z_source(fixed_redshift)
        return catalog

    def get_perturbation_object(self, tract_info, seed: int, **kwargs: Any):
        """Build the ``ShearHalo`` object described by the config.

        ``tract_info``, ``seed`` and ``kwargs`` are not used here.
        """
        cfg = self.config
        assert isinstance(cfg, CatalogHaloTaskConfig)
        halo_settings = {
            "mass": cfg.mass,
            "conc": cfg.conc,
            "z_lens": cfg.z_lens,
            "no_kappa": cfg.no_kappa,
        }
        return ShearHalo(**halo_settings)
class CatalogLogNormalTaskConfig(
    CatalogConfig,
    pipelineConnections=CatalogConnections,
):
    """Options for :class:`CatalogLogNormalTask`.

    Extends :class:`CatalogConfig` with a fixed source redshift and a
    switch controlling the kappa field.
    """

    z_source = Field[float](
        default=1.0,
        optional=True,
        doc="Fixed redshift for all galaxies.",
    )
    no_kappa = Field[bool](
        default=False,
        doc="whether to exclude kappa field",
    )

    def validate(self):
        """Run the inherited validation; no extra checks are added."""
        super().validate()

    def setDefaults(self):
        """Use the inherited defaults without modification."""
        super().setDefaults()
class CatalogLogNormalTask(CatalogTask):
    """Truth-catalog task that lenses galaxies with :class:`ShearLogNormalFlat`.

    Settings come from :class:`CatalogLogNormalTaskConfig`.
    """

    _DefaultName = "CatalogLogNormalTask"
    ConfigClass = CatalogLogNormalTaskConfig

    def __init__(self, **kwargs: Any):
        """Initialise the base task and check the config type."""
        super().__init__(**kwargs)
        assert isinstance(self.config, CatalogLogNormalTaskConfig)

    def prepare_galaxy_catalog(
        self,
        *,
        seed,
        tract_info,
    ):
        """Build the base catalog and optionally fix its source redshift.

        When ``config.z_source`` is set, it is applied to the catalog via
        ``set_z_source``; otherwise the catalog is returned as built.
        """
        cfg = self.config
        assert isinstance(cfg, CatalogLogNormalTaskConfig)
        catalog = super().prepare_galaxy_catalog(
            seed=seed,
            tract_info=tract_info,
        )
        fixed_redshift = cfg.z_source
        if fixed_redshift is None:
            return catalog
        catalog.set_z_source(fixed_redshift)
        return catalog

    def get_perturbation_object(self, tract_info, seed: int, **kwargs: Any):
        """Build a ``ShearLogNormalFlat`` field sized to the tract.

        The field size is the longer side of the tract bounding box
        (``lsst.geom.Box2I``), multiplied by 1.2 and by the pixel scale in
        degrees.  ``npix`` is 400 times that size, truncated to an int.
        ``seed`` is passed on to the field.
        """
        cfg = self.config
        assert isinstance(cfg, CatalogLogNormalTaskConfig)

        pixel_scale_deg = float(
            tract_info.getWcs().getPixelScale().asDegrees()
        )
        tract_bbox = tract_info.getBBox()
        longest_side = max(tract_bbox.getHeight(), tract_bbox.getWidth())
        field_size_deg = longest_side * 1.2 * pixel_scale_deg
        npix = int(400 * field_size_deg)

        # NOTE(review): ``config.no_kappa`` is declared but not forwarded
        # here -- confirm whether ShearLogNormalFlat should receive it.
        return ShearLogNormalFlat(
            z_source=cfg.z_source,
            field_size_deg=field_size_deg,
            npix=npix,
            seed=seed,
        )