# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Names exported by a star-import of this module: the task, its config class
# and its connections class (all three are defined further down in the file).
__all__ = [
"matchPipeConfig",
"matchPipe",
"matchPipeConnections",
]
import logging
import os
from typing import Any
import fitsio
import lsst.pipe.base.connectionTypes as cT
import numpy as np
from lsst.pex.config import DictField, Field
from lsst.pipe.base import (
PipelineTask,
PipelineTaskConfig,
PipelineTaskConnections,
Struct,
)
from lsst.skymap import BaseSkyMap
from lsst.utils.logging import LsstLogAdapter
from numpy.lib import recfunctions as rfn
from numpy.typing import NDArray
from scipy.optimize import linear_sum_assignment
from scipy.spatial import KDTree
from scipy.spatial.distance import cdist
# dm_colnames = [
# "base_SdssCentroid_x",
# "base_SdssCentroid_y",
# "base_GaussianFlux_instFlux",
# "base_GaussianFlux_instFluxErr",
# "modelfit_CModel_instFlux",
# "modelfit_CModel_instFluxErr",
# "ext_shapeHSM_HsmPsfMoments_xx",
# "ext_shapeHSM_HsmPsfMoments_yy",
# "ext_shapeHSM_HsmPsfMoments_xy",
# "ext_shapeHSM_HigherOrderMomentsPSF_04",
# "ext_shapeHSM_HigherOrderMomentsPSF_13",
# "ext_shapeHSM_HigherOrderMomentsPSF_22",
# "ext_shapeHSM_HigherOrderMomentsPSF_31",
# "ext_shapeHSM_HigherOrderMomentsPSF_40",
# ]
# "base_Blendedness_abs",
# "base_ClassificationExtendedness_value",
# "base_PsfFlux_instFlux",
# "base_PsfFlux_instFluxErr",
# "base_Variance_value",
class matchPipeConnections(
    PipelineTaskConnections,
    dimensions=("skymap", "tract", "patch"),
    defaultTemplates={"coaddName": "deep"},
):
    """Butler connections of the catalog-matching task.

    One quantum is defined per (skymap, tract, patch).  Dataset type names
    are built from the ``coaddName`` template, which defaults to ``"deep"``.
    """

    # Sky map; declared with the "skymap" dimension only.
    skyMap = cT.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="SkyMap to use in processing",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    # Per-patch AnaCal catalog; a single dataset per quantum.
    anacal_catalog = cT.Input(
        name="{coaddName}_coadd_anacal_catalog",
        doc="Source catalog with joint detection and measurement",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch"),
        multiple=False,
    )

    # Per-band measurement catalogs: one deferred handle per band, and the
    # connection is optional (minimum=0).
    dm_catalog = cT.Input(
        name="{coaddName}_coadd_meas",
        doc="Catalog containing all the single-band measurement information",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
        deferLoad=True,
        minimum=0,
    )

    # Truth catalog, declared per tract (no patch dimension); also optional
    # (minimum=0).
    truth_catalog = cT.Input(
        name="{coaddName}_coadd_truthCatalog",
        doc="Output truth catalog",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract"),
        multiple=False,
        deferLoad=False,
        minimum=0,
    )

    # Output written by the task, one dataset per patch.
    catalog = cT.Output(
        name="{coaddName}_coadd_anacal_match",
        doc="Source catalog with joint detection and measurement",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch"),
    )

    def __init__(self, *, config=None):
        """Forward ``config`` unchanged to the base-class constructor."""
        super().__init__(config=config)
class matchPipeConfig(
    PipelineTaskConfig,
    pipelineConnections=matchPipeConnections,
):
    """Configuration of the catalog-matching task."""

    mag_zero = Field[float](
        default=27.0,
        doc="magnitude zero point of the input catalog",
    )

    mag_max_truth = Field[float](
        default=28.0,
        doc="maximum magnitude limit of truth catalog",
    )

    do_select_primary = Field[bool](
        default=False,
        doc="whether select primary detection",
    )

    match_pix_distance = Field[int](
        default=6,
        doc="matching distance in pixels",
    )

    # Each value is a comma-separated list of column names; the task splits
    # it on "," and strips whitespace before selecting the columns.
    band_column_names = DictField(
        doc="column names for each band",
        keytype=str,
        itemtype=str,
        default={
            "g": "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr",
            "r": "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr",
            "i": (
                "base_SdssCentroid_x, base_SdssCentroid_y, "
                "base_GaussianFlux_instFlux, base_GaussianFlux_instFluxErr, "
                "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr, "
                "base_SdssShape_xx, base_SdssShape_yy, base_SdssShape_xy"
            ),
            "z": "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr",
            "y": "modelfit_CModel_instFlux, modelfit_CModel_instFluxErr",
        },
    )

    def validate(self):
        """Run the base-class validation; no extra checks are added."""
        super().validate()

    def setDefaults(self):
        """Apply the base-class defaults; nothing is overridden here."""
        super().setDefaults()
class matchPipe(PipelineTask):
    """Match an AnaCal catalog to DM measurements and to a truth catalog.

    Each detection is matched by pixel position, first to the per-band DM
    measurement catalog (when available) and then to the truth catalog
    (when available).  Only matched detections are kept at each step, and
    the columns of the matched reference rows are appended to them.
    """

    _DefaultName = "matchPipe"
    ConfigClass = matchPipeConfig

    def __init__(
        self,
        *,
        config: matchPipeConfig | None = None,
        log: logging.Logger | LsstLogAdapter | None = None,
        initInputs: dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        super().__init__(
            config=config, log=log, initInputs=initInputs, **kwargs
        )
        assert isinstance(self.config, matchPipeConfig)
        # Reference table holding the truth ``i_ab`` magnitudes; read lazily
        # by ``merge_truth`` and then reused for the life of the task.
        self._cat_ref: np.ndarray | None = None

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs of one quantum, run the matching, store the result.

        Both ``dm_catalog`` and ``truth_catalog`` are declared with
        ``minimum=0``, so either may be missing for a given quantum; a
        missing input is passed to `run` as `None`.
        """
        assert isinstance(self.config, matchPipeConfig)
        inputs = butlerQC.get(inputRefs)
        tract = int(butlerQC.quantum.dataId["tract"])
        patch = int(butlerQC.quantum.dataId["patch"])
        skyMap = inputs["skyMap"]

        dm_catalog = self._load_dm_catalog(inputs.get("dm_catalog") or [])

        # The optional truth input may be absent from the mapping or be None.
        truth_table = inputs.get("truth_catalog")
        truth_catalog = (
            None if truth_table is None else truth_table.as_array()
        )

        anacal_catalog = inputs["anacal_catalog"].as_array()
        # Keep the row number of each detection in the input catalog so
        # matched rows can be traced back after unmatched rows are dropped.
        anacal_catalog = rfn.append_fields(
            base=anacal_catalog,
            names="index",
            data=np.arange(len(anacal_catalog)),
            dtypes="i4",
            usemask=False,
        )

        outputs = self.run(
            skyMap=skyMap,
            tract=tract,
            patch=patch,
            catalog=anacal_catalog,
            dm_catalog=dm_catalog,
            truth_catalog=truth_catalog,
        )
        butlerQC.put(outputs, outputRefs)

    def _load_dm_catalog(self, dm_handles):
        """Read the configured columns of every band into one array.

        Parameters
        ----------
        dm_handles : iterable
            Deferred-load handles of the per-band measurement catalogs.

        Returns
        -------
        dm_catalog : `numpy.ndarray` or `None`
            Structured array whose fields are named ``{band}_{column}``, or
            `None` when no band could be loaded.
        """
        assert isinstance(self.config, matchPipeConfig)
        handles_by_band = {
            handle.dataId["band"]: handle for handle in dm_handles
        }
        per_band = []
        for band, handle in handles_by_band.items():
            if band not in self.config.band_column_names:
                # No columns were requested for this band: skip it instead
                # of failing the whole quantum with a KeyError.
                self.log.warning(
                    "No column names configured for band %s; skipping it",
                    band,
                )
                continue
            column_spec = self.config.band_column_names[band]
            dm_colnames = [
                name.strip()
                for name in column_spec.split(",")
                if name.strip()
            ]
            table = handle.get()
            array = table.asAstropy().as_array()[dm_colnames]
            if self.config.do_select_primary:
                array = array[table["detect_isPrimary"]]
            array = rfn.repack_fields(array)
            renames = {name: f"{band}_{name}" for name in dm_colnames}
            per_band.append(rfn.rename_fields(array, renames))
        if not per_band:
            return None
        return rfn.merge_arrays(per_band, flatten=True)

    def match(self, ana_coords, mrc_coords, thres=6):
        """Match two sets of 2-D positions one-to-one.

        Parameters
        ----------
        ana_coords : array-like, shape (N, 2)
            Positions of the objects to be matched.
        mrc_coords : array-like, shape (M, 2)
            Positions of the reference objects.
        thres : `float`, optional
            Maximum separation of a match, in the units of the coordinates.

        Returns
        -------
        ana_idx, mrc_idx : `numpy.ndarray` of `int`
            Row numbers into ``ana_coords`` and ``mrc_coords`` of the matched
            pairs (``ana_idx[k]`` is matched to ``mrc_idx[k]``).

        Notes
        -----
        Rows with a non-finite coordinate are never matched.  Every object
        is first paired with its nearest reference; references claimed by
        exactly one object are accepted, and the remaining objects and
        references are paired by a minimum-total-distance assignment.
        """
        ana_coords = np.asarray(ana_coords, dtype=float)
        mrc_coords = np.asarray(mrc_coords, dtype=float)
        no_match = np.empty(0, dtype=np.intp)
        if len(ana_coords) == 0 or len(mrc_coords) == 0:
            return no_match, no_match.copy()

        # The KD-tree cannot hold or query non-finite positions, so match
        # the finite rows only and map the result back to input row numbers.
        ana_ok = np.flatnonzero(np.isfinite(ana_coords).all(axis=1))
        mrc_ok = np.flatnonzero(np.isfinite(mrc_coords).all(axis=1))
        if ana_ok.size == 0 or mrc_ok.size == 0:
            return no_match, no_match.copy()

        ana_idx, mrc_idx = self._match_finite(
            ana_coords[ana_ok], mrc_coords[mrc_ok], thres
        )
        return ana_ok[ana_idx], mrc_ok[mrc_idx]

    def _match_finite(self, ana_coords, mrc_coords, thres):
        """Match non-empty, all-finite coordinate arrays (see `match`)."""
        # Cost given to pairs farther apart than ``thres``; a finite value
        # keeps the assignment problem feasible.
        unreachable = 1e5

        mrc_tree = KDTree(mrc_coords)
        match_dist, match_ndx = mrc_tree.query(ana_coords)
        within = match_dist < thres
        ana_idx = np.flatnonzero(within)
        mrc_idx = match_ndx[within]

        # References claimed by more than one object need disambiguation.
        uids, mrc_counts = np.unique(mrc_idx, return_counts=True)
        repeated_mrc = uids[mrc_counts > 1]
        if repeated_mrc.size == 0:
            return ana_idx, mrc_idx

        is_unique = ~np.isin(mrc_idx, repeated_mrc)
        uniq_ana_idx = ana_idx[is_unique]
        uniq_mrc_idx = mrc_idx[is_unique]

        # Everything not used by an accepted pair stays available.
        ana_free = np.ones(len(ana_coords), dtype=bool)
        ana_free[uniq_ana_idx] = False
        mrc_free = np.ones(len(mrc_coords), dtype=bool)
        mrc_free[uniq_mrc_idx] = False
        remain_ana = np.flatnonzero(ana_free)

        # Only pairs within ``thres`` can be assigned, so restrict the dense
        # distance matrix to objects that have any reference within reach
        # and to the free references near those objects.  The tiny margin
        # guards against round-off differences between the tree and cdist;
        # the exact threshold is re-applied on the cdist distances below.
        reach = thres * (1.0 + 1e-9)
        cand_ana = remain_ana[match_dist[remain_ana] <= reach]

        ana_idx2 = np.empty(0, dtype=np.intp)
        mrc_idx2 = np.empty(0, dtype=np.intp)
        if cand_ana.size > 0:
            neighbours = mrc_tree.query_ball_point(
                ana_coords[cand_ana], r=reach
            )
            near_mrc = np.unique(
                np.concatenate(
                    [np.asarray(hits, dtype=np.intp) for hits in neighbours]
                )
            )
            cand_mrc = near_mrc[mrc_free[near_mrc]]
            if cand_mrc.size > 0:
                dist_matrix = cdist(
                    ana_coords[cand_ana], mrc_coords[cand_mrc]
                )
                dist_matrix[dist_matrix > thres] = unreachable
                finite_rows = np.any(dist_matrix < unreachable, axis=1)
                finite_cols = np.any(dist_matrix < unreachable, axis=0)
                if np.any(finite_rows) and np.any(finite_cols):
                    sub_dist = dist_matrix[np.ix_(finite_rows, finite_cols)]
                    row, col = linear_sum_assignment(sub_dist)
                    # Drop assignments that were forced onto far pairs.
                    valid = sub_dist[row, col] < thres
                    rows_kept = np.flatnonzero(finite_rows)[row[valid]]
                    cols_kept = np.flatnonzero(finite_cols)[col[valid]]
                    ana_idx2 = cand_ana[rows_kept]
                    mrc_idx2 = cand_mrc[cols_kept]

        final_ana_idx = np.concatenate([uniq_ana_idx, ana_idx2])
        final_mrc_idx = np.concatenate([uniq_mrc_idx, mrc_idx2])
        return final_ana_idx, final_mrc_idx

    def _match_to_reference(self, src, mrc_coords, pixel_scale):
        """Match the detections in ``src`` to reference pixel positions.

        The detection positions are ``x1_det`` and ``x2_det`` divided by
        ``pixel_scale`` (presumably arcsec converted to pixels -- confirm
        against the producer of the catalog).  The matching radius is the
        ``match_pix_distance`` configuration value.
        """
        assert isinstance(self.config, matchPipeConfig)
        ana_coords = np.vstack(
            (src["x1_det"] / pixel_scale, src["x2_det"] / pixel_scale)
        ).T
        return self.match(
            ana_coords, mrc_coords, thres=self.config.match_pix_distance
        )

    def merge_dm(self, src: np.ndarray, mrc: np.ndarray, pixel_scale=0.168):
        """Append the matched DM measurements to the detections.

        Parameters
        ----------
        src : `numpy.ndarray`
            Detection catalog with ``x1_det`` and ``x2_det`` fields.
        mrc : `numpy.ndarray`
            DM catalog with band-prefixed fields.  The i-band centroid and
            Gaussian flux are always used for the matching, whatever
            ``band_column_names`` contains.
        pixel_scale : `float`, optional
            Divisor converting the detection positions to pixels.

        Returns
        -------
        combined : `numpy.ndarray`
            Matched detections followed by the fields of their DM match.

        Raises
        ------
        ValueError
            Raised if ``mrc`` lacks one of the required i-band fields.
        """
        assert isinstance(self.config, matchPipeConfig)
        required = (
            "i_base_GaussianFlux_instFlux",
            "i_base_SdssCentroid_x",
            "i_base_SdssCentroid_y",
        )
        available = mrc.dtype.names or ()
        missing = [name for name in required if name not in available]
        if missing:
            raise ValueError(
                f"DM catalog lacks fields required for matching: {missing}"
            )

        magz = self.config.mag_zero
        # Zero or negative fluxes give an infinite or NaN magnitude, which
        # fails the cut below; only the numpy warnings are silenced.
        with np.errstate(divide="ignore", invalid="ignore"):
            mag_mrc = magz - 2.5 * np.log10(
                mrc["i_base_GaussianFlux_instFlux"]
            )
        mrc = mrc[mag_mrc < self.config.mag_max_truth]
        mrc_coords = np.vstack(
            (mrc["i_base_SdssCentroid_x"], mrc["i_base_SdssCentroid_y"])
        ).T

        src_idx, mrc_idx = self._match_to_reference(
            src, mrc_coords, pixel_scale
        )
        return rfn.merge_arrays(
            (src[src_idx], mrc[mrc_idx]),
            flatten=True,
            usemask=False,
        )

    def merge_truth(
        self,
        src: np.ndarray,
        mrc: np.ndarray,
        pixel_scale=0.168,
        catsim_dir: str | None = None,
        wcs=None,
    ):
        """Append the matched truth index and redshift to the detections.

        Parameters
        ----------
        src : `numpy.ndarray`
            Detection catalog with ``x1_det`` and ``x2_det`` fields.
        mrc : `numpy.ndarray`
            Truth catalog with ``indices``, ``ra``, ``dec`` and ``redshift``
            fields; ``indices`` are row numbers into the reference table.
        pixel_scale : `float`, optional
            Divisor converting the detection positions to pixels.
        catsim_dir : `str`, optional
            Directory containing ``OneDegSq.fits``.  Defaults to the
            ``CATSIM_DIR`` environment variable, then to ``"."``.  The table
            is read once per task instance, so this argument only matters
            on the first call.
        wcs : optional
            Object whose ``skyToPixelArray(ra, dec, degrees=True)`` converts
            the truth positions to pixels.  Required.

        Returns
        -------
        combined : `numpy.ndarray`
            Matched detections followed by ``truth_index`` and ``redshift``.

        Raises
        ------
        ValueError
            Raised if ``wcs`` is `None`.
        """
        assert isinstance(self.config, matchPipeConfig)
        # Validate before any file I/O; a plain exception is used because
        # assert statements are removed when Python runs with -O.
        if wcs is None:
            raise ValueError("wcs is required for merge_truth")

        if self._cat_ref is None:
            catsim_dir = catsim_dir or os.environ.get("CATSIM_DIR", ".")
            path = os.path.join(catsim_dir, "OneDegSq.fits")
            self.log.info("Caching truth catalog reference from %s", path)
            self._cat_ref = fitsio.read(
                path,
                columns=["i_ab"],
            )
        assert self._cat_ref is not None

        mag_mrc = self._cat_ref["i_ab"][mrc["indices"]]
        mrc = mrc[mag_mrc < self.config.mag_max_truth]
        x_mrc, y_mrc = wcs.skyToPixelArray(
            np.array(mrc["ra"]),
            np.array(mrc["dec"]),
            degrees=True,
        )
        mrc_coords = np.vstack((x_mrc, y_mrc)).T

        src_idx, mrc_idx = self._match_to_reference(
            src, mrc_coords, pixel_scale
        )
        # Only the truth row number and the redshift are carried over.
        final_mrc = rfn.repack_fields(
            mrc[mrc_idx][["indices", "redshift"]]
        )
        final_mrc = rfn.rename_fields(final_mrc, {"indices": "truth_index"})
        return rfn.merge_arrays(
            (src[src_idx], final_mrc),
            flatten=True,
            usemask=False,
        )

    def run(
        self,
        *,
        skyMap,
        tract: int,
        patch: int,
        catalog: NDArray,
        dm_catalog: NDArray | None = None,
        truth_catalog: NDArray | None = None,
        catsim_dir: str | None = None,
        **kwargs,
    ):
        """Match ``catalog`` to the DM and truth catalogs that are given.

        Parameters
        ----------
        skyMap
            Sky map; ``skyMap[tract][patch].getWcs()`` supplies the pixel
            scale and ``skyMap[tract].getWcs()`` the WCS for the truth match.
        tract, patch : `int`
            Identifiers of the region being processed.
        catalog : `numpy.ndarray`
            Detection catalog.
        dm_catalog : `numpy.ndarray`, optional
            DM measurement catalog; the DM match is skipped when `None`.
        truth_catalog : `numpy.ndarray`, optional
            Truth catalog; the truth match is skipped when `None`.
        catsim_dir : `str`, optional
            Forwarded to `merge_truth`.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with a single ``catalog`` attribute.  The matches are
            applied in sequence, so the truth match only sees detections
            that survived the DM match.
        """
        assert isinstance(self.config, matchPipeConfig)
        # Pixel scale in arcsec per pixel.
        pixel_scale = (
            skyMap[tract][patch].getWcs().getPixelScale().asDegrees() * 3600
        )
        if dm_catalog is not None:
            catalog = self.merge_dm(catalog, dm_catalog, pixel_scale)
        if truth_catalog is not None:
            wcs = skyMap[tract].getWcs()
            catalog = self.merge_truth(
                catalog,
                truth_catalog,
                pixel_scale,
                catsim_dir=catsim_dir,
                wcs=wcs,
            )
        return Struct(catalog=catalog)