import os
from abc import ABC, abstractmethod
import numpy as np
from scipy.integrate import simpson
from scipy.optimize import minimize_scalar
from .utils import _resolve_cut_name
[docs]
# Redshift grid used as the z-axis by the point-estimate helpers below.
Z_GRIDS = np.linspace(start=Z_MIN, stop=Z_MAX, num=NUM_Z_GRIDS)
# Cumulative-probability levels at which the CDF is inverted for percentiles.
PROBS = np.asarray((0.025, 0.16, 0.5, 0.84, 0.975), dtype=float)
# 1 / (1 + z) on the grid, evaluated once so risk() need not redo the division.
INV1PZ = 1.0 / (Z_GRIDS + 1.0)
[docs]
def risk(zx: float, p_norm: np.ndarray) -> float:
    """Return the expected loss of the candidate redshift ``zx``.

    With ``t = (zx - z) / ((1 + z) * GAMMA_RISK)`` evaluated on ``Z_GRIDS``,
    the loss is ``t**2 / (1 + t**2)`` (equivalently ``1 - 1 / (1 + t**2)``).
    The result is the Simpson-rule integral of ``p_norm * loss`` over
    ``Z_GRIDS``.

    Parameters
    ----------
    zx : float
        Candidate point estimate.
    p_norm : np.ndarray
        Weights on ``Z_GRIDS``; must broadcast against the grid.

    Returns
    -------
    float
        Integrated, weighted loss.
    """
    scaled = ((zx - Z_GRIDS) * INV1PZ) / GAMMA_RISK
    squared = scaled * scaled
    integrand = p_norm * (squared / (1.0 + squared))
    return float(simpson(integrand, x=Z_GRIDS))
[docs]
def get_point_estimate(p):
    """Summarise a single PDF sampled on ``Z_GRIDS`` by seven point estimates.

    Parameters
    ----------
    p : array-like
        PDF values on ``Z_GRIDS`` (need not be normalised).

    Returns
    -------
    tuple
        ``(zmode, z025, z160, z500, z840, z975, zbest)`` where ``zmode`` is
        the grid value at the largest entry of ``p``, the five ``zNNN`` values
        are the ``PROBS`` quantiles read off the cumulative sum of ``p``, and
        ``zbest`` minimises ``risk`` within ``[Z_MIN, Z_MAX]``. All seven are
        NaN when the sum of ``p`` is non-finite or not strictly positive.
    """
    mass = float(np.sum(p))
    if not (np.isfinite(mass) and mass > 0.0):
        return (np.nan,) * 7

    # Sum-normalised weights used by the risk integral.
    weights = p / mass

    # Quantiles come from the discrete CDF (cumulative sum over the grid).
    cumulative = np.cumsum(p, dtype=float) / mass
    quantiles = np.interp(PROBS, cumulative, Z_GRIDS)

    def expected_loss(zx):
        return risk(zx, weights)

    fit = minimize_scalar(
        expected_loss, bounds=(Z_MIN, Z_MAX), method="bounded"
    )

    # Mode is the grid point where p peaks.
    peak = Z_GRIDS[int(np.argmax(p))]
    return (peak, *quantiles, fit.x)
[docs]
def get_point_estimates_from_pdfs(
    pdfs: np.ndarray,
):
    """
    Compute point estimates for a stack of PDFs sampled on the redshift grid.

    Each row of ``pdfs`` is passed to ``get_point_estimate``.

    Returns dict of arrays (shape (N,)):
    - zmode : z_grid[argmax p(z)]
    - z025, z160, z500, z840, z975 : CDF percentiles at [0.025, 0.16, 0.50,
      0.84, 0.975]
    - zbest : argmin_zx ∫ p_norm(z) * loss(zx,z) dz (bounded to [z_grid[0],
      z_grid[-1]])

    Raises ValueError when ``pdfs`` is not two-dimensional.
    """
    if pdfs.ndim != 2:
        raise ValueError(f"pdfs must be 2D (N, M); got {pdfs.shape}")

    # Key order matches the tuple order returned by get_point_estimate.
    names = ("zmode", "z025", "z160", "z500", "z840", "z975", "zbest")
    n_rows = pdfs.shape[0]
    columns = {name: np.full(n_rows, np.nan, dtype=float) for name in names}

    for row, pdf in enumerate(pdfs):
        for name, value in zip(names, get_point_estimate(pdf)):
            columns[name][row] = value
    return columns
[docs]
def get_color(
    src: np.ndarray,
    *,
    bands: str = "grizy",
    ref_band: str = "i",
    mag_zero: float = 30.0,
    comp: int = 1,
    dg: float = 0.0,
    flux_name: str = "gauss2",
    include_mag_err: bool = False,
    extinction: np.ndarray | None = None,
) -> np.ndarray:
    """
    Build a float32 feature matrix of one reference magnitude plus
    adjacent-band colors.

    For every band ``b`` the columns ``{b}_flux<fn>``, ``{b}_dflux<fn>_dg<comp>``
    and ``{b}_flux<fn>_err`` are read from ``src`` (``<fn>`` comes from
    ``_resolve_cut_name(flux_name)``). The flux used is ``flux + dg * dflux``.
    Sources with non-positive flux get magnitude 40.0 and magnitude error 1.0.
    When ``extinction`` is given, ``extinction["a_<b>"]`` is subtracted from
    the magnitude of each positive-flux source.

    Returns
    -------
    np.ndarray
        If include_mag_err=False: shape (N, 1 + (len(bands)-1))
            [ref_mag, (b0-b1), (b1-b2), ...]
        If include_mag_err=True:  shape (N, 1 + 2*(len(bands)-1))
            [ref_mag, (b0-b1), err01, (b1-b2), err12, ...]

    Raises
    ------
    ValueError
        If ``ref_band`` is not found in ``bands``.
    """
    suffix = _resolve_cut_name(flux_name)
    # d(mag) = (2.5 / ln 10) * d(flux) / flux
    err_scale = 2.5 / np.log(10.0)
    n_src = src.shape[0]

    band_mags = []
    band_errs = []
    # Bands are processed in the order given by `bands`.
    for band in bands:
        flux_key = f"{band}_flux{suffix}"
        raw_flux = src[flux_key]
        response = src[f"{band}_dflux{suffix}_dg{comp}"]
        sigma = src[f"{flux_key}_err"]

        shifted = raw_flux + dg * response
        detected = shifted > 0

        magnitude = np.full(n_src, 40.0, dtype=np.float64)  # faint default
        with np.errstate(divide="ignore", invalid="ignore"):
            bright = mag_zero - 2.5 * np.log10(shifted[detected])
            if extinction is not None:
                bright = bright - extinction[f"a_{band}"][detected]
            magnitude[detected] = bright
        band_mags.append(magnitude)

        if include_mag_err:
            magnitude_err = np.full(n_src, 1.0, dtype=np.float64)
            with np.errstate(divide="ignore", invalid="ignore"):
                magnitude_err[detected] = err_scale * (
                    sigma[detected] / shifted[detected]
                )
            band_errs.append(magnitude_err)

    n_colors = len(bands) - 1
    width = 1 + (2 * n_colors if include_mag_err else n_colors)
    feat = np.empty((n_src, width), dtype=np.float32)

    try:
        ref_idx = bands.index(ref_band)
    except ValueError:
        raise ValueError(f"ref_band={ref_band!r} not found in bands={bands!r}")
    feat[:, 0] = band_mags[ref_idx]

    # Column 0 holds the reference magnitude; colors (and errors) follow.
    col = 1
    for k in range(n_colors):
        feat[:, col] = band_mags[k] - band_mags[k + 1]
        col += 1
        if include_mag_err:
            feat[:, col] = np.hypot(band_errs[k], band_errs[k + 1])
            col += 1
    return feat
# ------------------------
# Z-Estimator Implementations
# ------------------------
[docs]
class zEstimator(ABC):
@abstractmethod
[docs]
def get_z(
self,
src: np.ndarray,
*,
mag_zero: float = 30.0,
flux_name: str = "gauss2",
bands: str = "grizy",
ref_band: str = "i",
comp: int = 1,
dg: float = 0.0,
flux_name2: str | None = None,
flux_name3: str | None = None,
extinction: np.ndarray | None = None,
**kwargs,
) -> dict:
"""Method to get redshift point estimates
"""
[docs]
def get_zsel(
self,
src: np.ndarray,
*,
mag_zero: float = 30.0,
flux_name: str = "gauss2",
bands: str = "grizy",
ref_band: str = "i",
comp: int = 1,
dg: float = 0.0,
z_point_name: str = "zmode",
flux_name2: str | None = None,
flux_name3: str | None = None,
extinction: np.ndarray | None = None,
**kwargs,
):
zout = self.get_z(
src=src, mag_zero=mag_zero, flux_name=flux_name,
bands=bands, ref_band=ref_band, comp=comp, dg=dg,
flux_name2=flux_name2, flux_name3=flux_name3,
extinction=extinction, **kwargs,
)
zpoint = zout[z_point_name]
width95 = zout["z975"] - zout["z025"]
return zpoint, width95
[docs]
class flexzboostEstimator(zEstimator):
    """
    Wraps a FlexZBoost-like predictor object with a uniform `get_z` API.
    """

    def __init__(
        self, pz_obj,
    ):
        """Store the predictor and pin its inner model to a single job.

        Parameters
        ----------
        pz_obj
            Predictor exposing ``predict(features, n_grid=...)`` and a
            ``model.models.n_jobs`` attribute.
        """
        # The reference must be stored before it is configured: both the line
        # below and get_z() go through self.pz_obj.
        self.pz_obj = pz_obj
        self.pz_obj.model.models.n_jobs = 1

    def get_z(
        self,
        src: np.ndarray,
        *,
        mag_zero: float = 30.0,
        flux_name: str = "gauss2",
        bands: str = "grizy",
        ref_band: str = "i",
        comp: int = 1,
        dg: float = 0.0,
        flux_name2: str | None = None,
        flux_name3: str | None = None,
        include_mag_err: bool = False,
        return_pdfs: bool = False,
        extinction: np.ndarray | None = None,
        **kwargs,
    ) -> dict:
        """Predict PDFs from color features and reduce them to point estimates.

        Features come from ``get_color``; PDFs come from
        ``self.pz_obj.predict(features, n_grid=NUM_Z_GRIDS)``.
        ``flux_name2``, ``flux_name3`` and ``**kwargs`` are accepted for API
        compatibility and are not used here.

        Returns
        -------
        dict
            Output of ``get_point_estimates_from_pdfs``; with
            ``return_pdfs=True`` the PDFs are added under the key ``"pdfs"``.
        """
        colors = get_color(
            src,
            mag_zero=mag_zero,
            comp=comp,
            dg=dg,
            flux_name=flux_name,
            bands=bands,
            ref_band=ref_band,
            include_mag_err=include_mag_err,
            extinction=extinction,
        )
        # The second return value of predict() is discarded.
        # NOTE(review): presumably it is the z grid, which should match
        # Z_GRIDS -- confirm against the predictor.
        pdfs, _ = self.pz_obj.predict(colors, n_grid=NUM_Z_GRIDS)
        points = get_point_estimates_from_pdfs(pdfs)
        if return_pdfs:
            points["pdfs"] = pdfs
        return points
[docs]
def load_bpz_templates(
    data_path: str,
    bands: str,
    filter_name: str = "DC2LSST",
    spectra_name: str = "cosmossedswdust136.list",
):
    """Load BPZ template fluxes on Z_GRIDS for provided filter set.

    Parameters
    ----------
    data_path : str
        Directory containing ``SED/<spectra_name>`` and the ``AB/`` model
        files named ``<spectrum>.<filter_name>_<band>.AB``.
    bands : str
        One character per band; each yields the filter
        ``f"{filter_name}_{band}"``.
    filter_name : str
        Prefix of the filter names.
    spectra_name : str
        File under ``SED/`` listing the template spectra.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(Z_GRIDS), n_spectra, n_filters)``.

    Raises
    ------
    FileNotFoundError
        If an expected ``.AB`` model file is missing.
    """
    # Imported lazily so the module loads without desc_bpz installed.
    from desc_bpz.useful_py3 import get_data, get_str, match_resol

    filters = [f"{filter_name}_{b}" for b in bands]
    z = Z_GRIDS

    spectra_file = os.path.join(data_path, "SED", spectra_name)
    # Each entry has its last four characters removed before use
    # (presumably a file extension such as ".sed").
    spectra = [s[:-4] for s in get_str(spectra_file)]

    flux_templates = np.zeros((len(z), len(spectra), len(filters)))
    for i, s in enumerate(spectra):
        for j, f in enumerate(filters):
            model_path = os.path.join(data_path, "AB", f"{s}.{f}.AB")
            # Explicit check instead of `assert`, which is removed under -O.
            if not os.path.isfile(model_path):
                raise FileNotFoundError(f"Cannot find model: {model_path}")
            zo, f_mod_0 = get_data(model_path, (0, 1))
            # Resample the model onto the module redshift grid.
            flux_templates[:, i, j] = match_resol(zo, f_mod_0, z)
    return flux_templates
[docs]
class bpzEstimator(zEstimator):
    """
    Wraps BPZ template/prior configuration with a uniform `get_z` API.
    """

    def __init__(
        self,
        flux_templates: np.ndarray,
        prior_dict: dict,
        zp_errors,
    ):
        """
        Parameters
        ----------
        flux_templates : array, shape (NZ, NT, NF)
        prior_dict : dict
        zp_errors : zero-point mag errors per band, same order as `bands`
        """
        self.flux_templates = flux_templates
        self.prior_dict = prior_dict
        self.zp_errors = np.array(zp_errors, dtype=float)

    def _measure_one_source(
        self,
        flux: np.ndarray,
        flux_err: np.ndarray,
        mag_0: float,
    ):
        """Return the (unnormalised) posterior PDF of one source.

        The likelihood from ``p_c_z_t`` is multiplied by the prior from
        ``prior_function`` and summed over axis 1.

        Parameters
        ----------
        flux, flux_err : np.ndarray
            Per-band fluxes and errors of a single source.
        mag_0 : float
            Reference-band magnitude passed to the prior.
        """
        from desc_bpz.bpz_tools_py3 import p_c_z_t
        from desc_bpz.prior_from_dict import prior_function

        nt = self.flux_templates.shape[1]
        pczt = p_c_z_t(flux, flux_err, self.flux_templates)
        likelihood = pczt.likelihood
        prior = prior_function(Z_GRIDS, mag_0, self.prior_dict, nt)
        posterior = likelihood * prior
        # NOTE(review): axis 1 is presumably the template axis (NT) -- confirm.
        return posterior.sum(axis=1)

    def get_z(
        self,
        src: np.ndarray,
        *,
        mag_zero: float = 30.0,
        flux_name: str = "gauss2",
        bands: str = "grizy",
        ref_band: str = "i",
        comp: int = 1,
        dg: float = 0.0,
        return_pdfs: bool = False,
        extinction: np.ndarray | None = None,
        **kwargs,
    ) -> dict:
        """Compute BPZ posteriors for every source and reduce them to points.

        Per band, the columns ``{b}_flux<fn>``, ``{b}_dflux<fn>_dg<comp>`` and
        ``{b}_flux<fn>_err`` are read from ``src``; the flux used is
        ``flux + dg * dflux``. Non-positive fluxes get magnitude 40.0 and
        magnitude error 1.0. Magnitudes are converted back to pseudo-fluxes
        and the zero-point error is added in quadrature.

        Returns
        -------
        dict
            Output of ``get_point_estimates_from_pdfs``; with
            ``return_pdfs=True`` the PDFs are added under ``"pdfs"``. An empty
            ``src`` yields empty arrays.

        Raises
        ------
        ValueError
            If ``ref_band`` is not found in ``bands``.
        """
        # Fail fast with the same message get_color() uses, rather than the
        # bare str.index error raised after all the work is done.
        try:
            m_0_col = bands.index(ref_band)
        except ValueError:
            raise ValueError(
                f"ref_band={ref_band!r} not found in bands={bands!r}"
            ) from None

        fn = _resolve_cut_name(flux_name)
        A = 2.5 / np.log(10.0)
        n = src.shape[0]
        mags = []
        merrs = []
        for b in bands:
            flux_col = f"{b}_flux{fn}"
            dflux_col = f"{b}_dflux{fn}_dg{comp}"
            err_col = f"{flux_col}_err"
            flux_base = src[flux_col]
            dflux = src[dflux_col]
            ferr = src[err_col]
            flux = flux_base + dg * dflux
            mag = np.full(n, 40.0, dtype=np.float64)  # default to faint gal
            mag_err = np.full(n, 1.0, dtype=np.float64)
            pos = flux > 0
            with np.errstate(divide="ignore", invalid="ignore"):
                mag_err[pos] = A * (ferr[pos] / flux[pos])
            if extinction is None:
                with np.errstate(divide="ignore", invalid="ignore"):
                    mag[pos] = mag_zero - 2.5 * np.log10(flux[pos])
            else:
                a_ext = extinction[f"a_{b}"]
                with np.errstate(divide="ignore", invalid="ignore"):
                    mag[pos] = (
                        mag_zero - 2.5 * np.log10(flux[pos]) - a_ext[pos]
                    )
            mags.append(mag)
            merrs.append(mag_err)
        mags = np.array(mags).T
        merrs = np.array(merrs).T

        from desc_bpz.bpz_tools_py3 import e_mag2frac

        # NOTE(review): zp_errors is assumed to have one entry per band, in
        # the order of `bands`; this is not checked here.
        zp_frac = e_mag2frac(self.zp_errors)

        # Convert to pseudo-fluxes and propagate errors
        flux = 10.0**(-0.4 * mags)
        flux_err = flux * (10.0**(0.4 * merrs) - 1.0)
        add_err = (zp_frac * flux)**2
        flux_err = np.sqrt(flux_err**2 + add_err)
        mag0 = mags[:, m_0_col]
        # Free some memory
        del mags, merrs

        if n == 0:
            # np.stack rejects an empty list; return an empty (0, NZ) stack
            # so the point-estimate step yields empty arrays.
            pdfs = np.empty((0, len(Z_GRIDS)), dtype=float)
        else:
            # astype() instead of np.stack(dtype=...), which needs NumPy 1.24+.
            pdfs = np.stack(
                [
                    self._measure_one_source(flux[i], flux_err[i], mag0[i])
                    for i in range(n)
                ]
            ).astype(float, copy=False)
        points = get_point_estimates_from_pdfs(pdfs)
        if return_pdfs:
            points["pdfs"] = pdfs
        return points