Source code for xlens.process_pipe.match

# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

__all__ = [
    "matchPipeConfig",
    "matchPipe",
    "matchPipeConnections",
]

import logging
import os
from typing import Any

import fitsio
import lsst.pipe.base.connectionTypes as cT
import numpy as np
from lsst.pex.config import DictField, Field
from lsst.pipe.base import (
    PipelineTask,
    PipelineTaskConfig,
    PipelineTaskConnections,
    Struct,
)
from lsst.skymap import BaseSkyMap
from lsst.utils.logging import LsstLogAdapter
from numpy.lib import recfunctions as rfn
from numpy.typing import NDArray
from scipy.optimize import linear_sum_assignment
from scipy.spatial import KDTree
from scipy.spatial.distance import cdist

# dm_colnames = [
#     "base_SdssCentroid_x",
#     "base_SdssCentroid_y",
#     "base_GaussianFlux_instFlux",
#     "base_GaussianFlux_instFluxErr",
#     "modelfit_CModel_instFlux",
#     "modelfit_CModel_instFluxErr",
#     "ext_shapeHSM_HsmPsfMoments_xx",
#     "ext_shapeHSM_HsmPsfMoments_yy",
#     "ext_shapeHSM_HsmPsfMoments_xy",
#     "ext_shapeHSM_HigherOrderMomentsPSF_04",
#     "ext_shapeHSM_HigherOrderMomentsPSF_13",
#     "ext_shapeHSM_HigherOrderMomentsPSF_22",
#     "ext_shapeHSM_HigherOrderMomentsPSF_31",
#     "ext_shapeHSM_HigherOrderMomentsPSF_40",
# ]

# "base_Blendedness_abs",
# "base_ClassificationExtendedness_value",
# "base_PsfFlux_instFlux",
# "base_PsfFlux_instFluxErr",
# "base_Variance_value",


class matchPipeConnections(
    PipelineTaskConnections,
    dimensions=("skymap", "tract", "patch"),
    defaultTemplates={"coaddName": "deep"},
):
    """Butler connections for the catalog-matching task.

    One quantum is defined per ``(skymap, tract, patch)``.  Every dataset
    type name except the sky map is derived from the ``coaddName`` template,
    which defaults to ``"deep"``.
    """

    # Sky map; the task reads WCS information from it.
    skyMap = cT.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="SkyMap to use in processing",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    # Per-patch detection/measurement catalog that gets matched (required).
    anacal_catalog = cT.Input(
        name="{coaddName}_coadd_anacal_catalog",
        doc="Source catalog with joint detection and measurement",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch"),
        multiple=False,
    )

    # Per-band measurement catalogs.  Optional (minimum=0) and delivered as
    # deferred-load handles, one per band.
    dm_catalog = cT.Input(
        name="{coaddName}_coadd_meas",
        doc="Catalog containing all the single-band measurement information",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
        minimum=0,
        deferLoad=True,
    )

    # Per-tract truth catalog.  Optional (minimum=0) and loaded eagerly.
    truth_catalog = cT.Input(
        name="{coaddName}_coadd_truthCatalog",
        doc="Output truth catalog",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract"),
        multiple=False,
        minimum=0,
        deferLoad=False,
    )

    # Matched catalog written by the task, one per patch.
    catalog = cT.Output(
        name="{coaddName}_coadd_anacal_match",
        doc="Source catalog with joint detection and measurement",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch"),
    )

    def __init__(self, *, config=None):
        # No connection is added or removed based on the configuration.
        super().__init__(config=config)
class matchPipeConfig(
    PipelineTaskConfig,
    pipelineConnections=matchPipeConnections,
):
    """Configuration for the catalog-matching task."""

    # Zero point used to turn instrumental fluxes into magnitudes.
    mag_zero = Field[float](
        default=27.0,
        doc="magnitude zero point of the input catalog",
    )

    # Reference objects at or above this magnitude are dropped before
    # matching (applied to both the measurement and the truth catalog).
    mag_max_truth = Field[float](
        default=28.0,
        doc="maximum magnitude limit of truth catalog",
    )

    # When set, the per-band measurement catalogs are filtered with their
    # "detect_isPrimary" column before being merged.
    do_select_primary = Field[bool](
        default=False,
        doc="whether select primary detection",
    )

    # Maximum separation, in pixels, for two objects to count as a match.
    match_pix_distance = Field[int](
        default=6,
        doc="matching distance in pixels",
    )

    # For every band, a single comma-separated string of column names; the
    # task splits it on "," and strips whitespace from each entry.
    band_column_names = DictField(
        keytype=str,
        itemtype=str,
        doc="column names for each band",
        default={
            "g": "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr",
            "r": "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr",
            "i": (
                "base_SdssCentroid_x, "
                "base_SdssCentroid_y, "
                "base_GaussianFlux_instFlux, "
                "base_GaussianFlux_instFluxErr, "
                "modelfit_CModel_instFlux, "
                "modelfit_CModel_instFluxErr, "
                "base_SdssShape_xx, "
                "base_SdssShape_yy, "
                "base_SdssShape_xy"
            ),
            "z": "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr",
            "y": "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr",
        },
    )

    def validate(self):
        """Validate the configuration; no checks beyond the base class."""
        super().validate()

    def setDefaults(self):
        """Set default values; no overrides beyond the base class."""
        super().setDefaults()
class matchPipe(PipelineTask):
    """Match a detection catalog to measurement and truth catalogs.

    The input catalog (with detection positions ``x1_det`` / ``x2_det``) is
    cross-matched in pixel space, first against the merged per-band
    measurement catalog and then against the truth catalog.  Either
    reference catalog may be absent, in which case that step is skipped.
    Only matched rows are kept; the matched reference columns are appended
    to them.
    """

    _DefaultName = "matchPipe"
    ConfigClass = matchPipeConfig

    def __init__(
        self,
        *,
        config: matchPipeConfig | None = None,
        log: logging.Logger | LsstLogAdapter | None = None,
        initInputs: dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        super().__init__(
            config=config, log=log, initInputs=initInputs, **kwargs
        )
        assert isinstance(self.config, matchPipeConfig)
        # Reference table with the "i_ab" column, read from disk on the
        # first call to ``merge_truth`` and reused afterwards.
        self._cat_ref: np.ndarray | None = None

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs of one quantum, run the task and store the result.

        Parameters
        ----------
        butlerQC
            Quantum context used to read the inputs and write the outputs.
        inputRefs
            References to the input datasets.
        outputRefs
            References to the output datasets.
        """
        assert isinstance(self.config, matchPipeConfig)
        inputs = butlerQC.get(inputRefs)
        tract = int(butlerQC.quantum.dataId["tract"])
        patch = int(butlerQC.quantum.dataId["patch"])
        skyMap = inputs["skyMap"]

        dm_handles = inputs["dm_catalog"]
        if not dm_handles:
            dm_catalog = None
        else:
            handles_by_band = {
                handle.dataId["band"]: handle for handle in dm_handles
            }
            per_band = []
            for band, handle in handles_by_band.items():
                columns = [
                    name.strip()
                    for name in self.config.band_column_names[band].split(",")
                ]
                table = handle.get()
                records = table.asAstropy().as_array()[columns]
                if self.config.do_select_primary:
                    records = records[table["detect_isPrimary"]]
                records = rfn.repack_fields(records)
                # Prefix every column with its band so the merged table has
                # unique field names.
                renamed = {name: f"{band}_" + name for name in columns}
                per_band.append(rfn.rename_fields(records, renamed))
            dm_catalog = rfn.merge_arrays(per_band, flatten=True)

        # The truth-catalog connection declares ``minimum=0`` and ``run``
        # accepts ``None``, so tolerate a missing entry here instead of
        # failing on ``.as_array()``.
        # NOTE(review): whether the butler omits the key or stores None for
        # a missing optional input is not visible here; both are handled.
        truth_table = inputs.get("truth_catalog")
        truth_catalog = (
            None if truth_table is None else truth_table.as_array()
        )

        anacal_catalog = inputs["anacal_catalog"].as_array()
        # Record each row's position in the input catalog so matched rows
        # can be traced back after filtering.
        anacal_catalog = rfn.append_fields(
            base=anacal_catalog,
            names="index",
            data=np.arange(len(anacal_catalog)),
            dtypes="i4",
            usemask=False,
        )

        outputs = self.run(
            skyMap=skyMap,
            tract=tract,
            patch=patch,
            catalog=anacal_catalog,
            dm_catalog=dm_catalog,
            truth_catalog=truth_catalog,
        )
        butlerQC.put(outputs, outputRefs)

    def match(self, ana_coords, mrc_coords, thres=6):
        """Find one-to-one matches between two sets of 2-D positions.

        Every position in ``ana_coords`` is first paired with its nearest
        neighbour in ``mrc_coords`` when that neighbour is closer than
        ``thres``.  Reference positions claimed by exactly one source are
        accepted as is.  When a reference position is claimed by several
        sources, those sources are re-assigned among the reference
        positions not yet used by solving a minimum-cost assignment on the
        pairwise distances; assignments at ``thres`` or farther are
        rejected.

        Parameters
        ----------
        ana_coords : array-like, shape (N, 2)
            Positions to be matched.
        mrc_coords : array-like, shape (M, 2)
            Reference positions.
        thres : float, optional
            Maximum (exclusive) matching distance, in the units of the
            coordinates.

        Returns
        -------
        ana_idx, mrc_idx : numpy.ndarray
            Integer index arrays of equal length; ``ana_idx[k]`` in
            ``ana_coords`` is matched to ``mrc_idx[k]`` in ``mrc_coords``.
            Uncontested matches come first, followed by the re-assigned
            ones.
        """
        ana_coords = np.asarray(ana_coords, dtype=float)
        mrc_coords = np.asarray(mrc_coords, dtype=float)
        if len(ana_coords) == 0 or len(mrc_coords) == 0:
            # Nothing can match; also avoids building a tree on no data.
            return np.array([], dtype=int), np.array([], dtype=int)

        match_dist, match_ndx = KDTree(mrc_coords).query(ana_coords)
        in_range = match_dist < thres
        ana_idx = np.flatnonzero(in_range)
        mrc_idx = match_ndx[in_range]

        # Reference objects that are the nearest neighbour of 2+ sources.
        uids, counts = np.unique(mrc_idx, return_counts=True)
        repeated = uids[counts > 1]
        if repeated.size == 0:
            return ana_idx, mrc_idx

        contested = np.isin(mrc_idx, repeated)
        uniq_ana_idx = ana_idx[~contested]
        uniq_mrc_idx = mrc_idx[~contested]

        # Only sources whose nearest neighbour is contested can still gain
        # a match: uncontested ones are settled, and the rest have no
        # reference within range at all.  Restricting the rows to them
        # keeps the distance matrix small.
        conflict_ana = ana_idx[contested]
        remain_mrc = np.setdiff1d(np.arange(len(mrc_coords)), uniq_mrc_idx)

        dist = cdist(ana_coords[conflict_ana], mrc_coords[remain_mrc])
        # Same strict comparison as for the nearest-neighbour stage.
        within = dist < thres
        rows = np.flatnonzero(within.any(axis=1))
        cols = np.flatnonzero(within.any(axis=0))
        if rows.size == 0:
            return uniq_ana_idx, uniq_mrc_idx

        # Out-of-range pairs get a large finite cost so the solver always
        # has a feasible solution but avoids them whenever possible.
        cost = np.where(within, dist, 1e5)[np.ix_(rows, cols)]
        row, col = linear_sum_assignment(cost)
        valid = cost[row, col] < thres

        ana_idx2 = conflict_ana[rows[row[valid]]]
        mrc_idx2 = remain_mrc[cols[col[valid]]]
        final_ana_idx = np.concatenate([uniq_ana_idx, ana_idx2])
        final_mrc_idx = np.concatenate([uniq_mrc_idx, mrc_idx2])
        return final_ana_idx, final_mrc_idx

    def merge_dm(self, src: np.ndarray, mrc: np.ndarray, pixel_scale=0.168):
        """Match ``src`` to the measurement catalog and join the columns.

        Parameters
        ----------
        src : numpy.ndarray
            Structured array with ``x1_det`` and ``x2_det`` columns; they
            are divided by ``pixel_scale`` to obtain pixel positions.
        mrc : numpy.ndarray
            Structured array with ``i_base_GaussianFlux_instFlux``,
            ``i_base_SdssCentroid_x`` and ``i_base_SdssCentroid_y``
            columns.
        pixel_scale : float, optional
            Divisor converting the ``src`` positions to pixels.

        Returns
        -------
        combined : numpy.ndarray
            Matched rows of ``src`` with the columns of the matched ``mrc``
            rows appended.
        """
        assert isinstance(self.config, matchPipeConfig)
        flux = np.asarray(mrc["i_base_GaussianFlux_instFlux"], dtype=float)
        x_all = np.asarray(mrc["i_base_SdssCentroid_x"], dtype=float)
        y_all = np.asarray(mrc["i_base_SdssCentroid_y"], dtype=float)

        # Non-positive fluxes give inf/NaN magnitudes, which fail the cut
        # below; only the floating-point warnings are suppressed.
        with np.errstate(divide="ignore", invalid="ignore"):
            mag_mrc = self.config.mag_zero - 2.5 * np.log10(flux)

        # KDTree requires finite coordinates, so rows whose centroid is
        # not finite are dropped as well.
        keep = (
            (mag_mrc < self.config.mag_max_truth)
            & np.isfinite(x_all)
            & np.isfinite(y_all)
        )
        mrc = mrc[keep]

        mrc_coords = np.column_stack((x_all[keep], y_all[keep]))
        ana_coords = np.column_stack(
            (src["x1_det"] / pixel_scale, src["x2_det"] / pixel_scale)
        )
        src_idx, mrc_idx = self.match(
            ana_coords, mrc_coords, thres=self.config.match_pix_distance
        )

        return rfn.merge_arrays(
            (src[src_idx], mrc[mrc_idx]),
            flatten=True,
            usemask=False,
        )

    def merge_truth(
        self,
        src: np.ndarray,
        mrc: np.ndarray,
        pixel_scale=0.168,
        catsim_dir: str | None = None,
        wcs=None,
    ):
        """Match ``src`` to the truth catalog and append truth columns.

        Parameters
        ----------
        src : numpy.ndarray
            Structured array with ``x1_det`` and ``x2_det`` columns; they
            are divided by ``pixel_scale`` to obtain pixel positions.
        mrc : numpy.ndarray
            Truth catalog with ``indices``, ``ra``, ``dec`` and
            ``redshift`` columns.  ``indices`` is used to index the
            reference table read from ``OneDegSq.fits``.
        pixel_scale : float, optional
            Divisor converting the ``src`` positions to pixels.
        catsim_dir : str, optional
            Directory holding ``OneDegSq.fits``.  Falls back to the
            ``CATSIM_DIR`` environment variable, then to ``"."``.  Only
            consulted on the first call; the table is cached afterwards.
        wcs
            Object providing ``skyToPixelArray(ra, dec, degrees=True)``.
            Required despite the ``None`` default.

        Returns
        -------
        combined : numpy.ndarray
            Matched rows of ``src`` with ``truth_index`` (the truth
            ``indices`` value) and ``redshift`` appended.

        Raises
        ------
        ValueError
            If ``wcs`` is not given.
        """
        assert isinstance(self.config, matchPipeConfig)
        if wcs is None:
            # Fail before any file I/O; an explicit raise also survives
            # running Python with assertions disabled.
            raise ValueError("wcs is required for merge_truth")

        if self._cat_ref is None:
            catsim_dir = catsim_dir or os.environ.get("CATSIM_DIR", ".")
            path = os.path.join(catsim_dir, "OneDegSq.fits")
            self.log.info("Caching truth catalog reference from %s", path)
            self._cat_ref = fitsio.read(
                path,
                columns=["i_ab"],
            )
        assert self._cat_ref is not None

        mag_mrc = self._cat_ref[mrc["indices"]]["i_ab"]
        mrc = mrc[mag_mrc < self.config.mag_max_truth]

        x_mrc, y_mrc = wcs.skyToPixelArray(
            np.array(mrc["ra"]),
            np.array(mrc["dec"]),
            degrees=True,
        )
        ana_coords = np.column_stack(
            (src["x1_det"] / pixel_scale, src["x2_det"] / pixel_scale)
        )
        mrc_coords = np.column_stack((x_mrc, y_mrc))
        src_idx, mrc_idx = self.match(
            ana_coords, mrc_coords, thres=self.config.match_pix_distance
        )

        # Keep only the identifying truth columns and rename the index so
        # it does not collide with the "index" column of the source rows.
        final_mrc = rfn.repack_fields(mrc[mrc_idx][["indices", "redshift"]])
        final_mrc = rfn.rename_fields(final_mrc, {"indices": "truth_index"})

        return rfn.merge_arrays(
            (src[src_idx], final_mrc),
            flatten=True,
            usemask=False,
        )

    def run(
        self,
        *,
        skyMap,
        tract: int,
        patch: int,
        catalog: NDArray,
        dm_catalog: NDArray | None = None,
        truth_catalog: NDArray | None = None,
        catsim_dir: str | None = None,
        **kwargs,
    ):
        """Match ``catalog`` against whichever reference catalogs are given.

        Parameters
        ----------
        skyMap
            Sky map; indexed as ``skyMap[tract][patch]`` for the pixel
            scale and as ``skyMap[tract]`` for the WCS used with the truth
            catalog.
        tract, patch : int
            Identifiers of the region being processed.
        catalog : numpy.ndarray
            Catalog to match.
        dm_catalog : numpy.ndarray, optional
            Merged measurement catalog; skipped when ``None``.
        truth_catalog : numpy.ndarray, optional
            Truth catalog; skipped when ``None``.
        catsim_dir : str, optional
            Forwarded to ``merge_truth``.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        result : lsst.pipe.base.Struct
            Struct whose ``catalog`` attribute is the matched catalog.
        """
        assert isinstance(self.config, matchPipeConfig)
        # Pixel scale in arcseconds per pixel.
        pixel_scale = (
            skyMap[tract][patch].getWcs().getPixelScale().asDegrees() * 3600
        )

        if dm_catalog is not None:
            catalog = self.merge_dm(catalog, dm_catalog, pixel_scale)

        if truth_catalog is not None:
            catalog = self.merge_truth(
                catalog,
                truth_catalog,
                pixel_scale,
                catsim_dir=catsim_dir,
                wcs=skyMap[tract].getWcs(),
            )

        return Struct(catalog=catalog)