Source code for ynot.echelle

"""
echelle
-------

Spectral dispersion and slit length axes are generally not perfectly aligned with the rectilinear pixel grid of a spectrograph detector, complicating the extraction of echelle spectra.  There exists some mapping of each 2D :math:`(x,y)` pixel to a new coordinate system of wavelength and slit position :math:`(\lambda,s)`, with :math:`x` and :math:`y` in units of pixels, :math:`\lambda` in units of Ångstroms, and :math:`s` in units of arcseconds.  Each of the two new coordinates can therefore be represented as a scalar surface over :math:`x` and :math:`y`.  The `ynot` project infers this mapping for all pixels in an echelle order.  For example, this mapping could be parameterized as separable polynomials:

.. math::

   \lambda(x,y) &= \lambda_0 + c_1 x + c_2 x^2 + c_3 y

   s(x,y)      &= s_0 + b_1 y + b_2 x



Echellogram
############
"""

import torch
from torch import nn
from torch.distributions import Normal
import pandas as pd


class Echellogram(nn.Module):
    r"""A PyTorch layer that provides a parameter set and transformations to model echellograms.

    Args:
        device (str): Either "cuda" for GPU acceleration, or "cpu" otherwise
        ybounds (tuple): the :math:`y_0` and :math:`y_{max}` of the raw
            echellogram to analyze.  Default: (425, 510)
        dense_sky (bool): whether or not to treat the sky background as dense
            (~1400) parameters or fit a few (~3-10) lines from a pre-determined
            line list.  A pre-determined line list is needed for wavelength
            calibration, while a dense sky gives the best sky subtraction for
            weak lines that are not in the line list.  The best approach---a
            *hybrid* of the two---is not yet implemented.  Default: False
        skyline_file (str or None): optional path (or URL) of the
            whitespace-delimited sky line list used when ``dense_sky`` is
            False.  ``None`` keeps the built-in default location.
            Default: None

    Raises:
        ValueError: if ``ybounds`` does not span at least one row.
    """

    def __init__(
        self, device="cuda", ybounds=(425, 510), dense_sky=False, skyline_file=None
    ):
        super().__init__()

        def as_param(values):
            # Every trainable quantity is a float64 leaf tensor on `device`.
            return nn.Parameter(
                torch.tensor(values, dtype=torch.float64, device=device)
            )

        self.device = device
        self.y0 = ybounds[0]
        self.ymax = ybounds[1]
        self.ny = self.ymax - self.y0
        if self.ny <= 0:
            # Fail early with a clear message instead of an obscure error from
            # an empty/invalid pixel grid further down.
            raise ValueError(f"ybounds must satisfy y0 < ymax, got {ybounds!r}")

        # Fiducial [central wavelength, dispersion] consumed by `lam_xy`.
        # NOTE: built in the default dtype and then promoted to double, exactly
        # as before, so the numerical values are unchanged.
        self.fiducial = torch.tensor([21779.0, 0.310], device=device).double()

        # Pixel coordinate vectors and the (nx, ny) coordinate grids.
        self.nx = 1024
        self.xvec = torch.arange(0, self.nx, device=device).double()
        self.xn = (
            2 * (self.xvec - self.xvec.mean()) / (self.xvec.max() - self.xvec.min())
        )
        self.yvec = torch.arange(0, self.ny, device=device).double()
        # "ij" is what the legacy default did; being explicit avoids the
        # warning emitted by current torch releases.
        self.xx, self.yy = torch.meshgrid(self.xvec, self.yvec, indexing="ij")

        # First three Chebyshev polynomials evaluated on the normalized x axis.
        self.cheb_x = torch.stack(
            [torch.ones_like(self.xn), self.xn, 2 * self.xn ** 2 - 1]
        ).to(self.device)

        # This is sampled in log
        self.bkg_const = as_param(6.9)
        self.s_coeffs = as_param([14.635, 0.20352, -0.004426, 0.0])

        self.n_amps = 1500
        # These represent inputs to log.
        self.src_amps = nn.Parameter(
            torch.full((self.n_amps,), 4.1, dtype=torch.float64, device=device)
        )
        self.smoothness = as_param(-2.27)
        self.lam_coeffs = as_param([0.0, 0.0, 1.0, -0.7923])
        # One row of source-profile coefficients per nod position.
        self.p_coeffs = as_param([[3, -1, 0, 0], [9, -1, 0, 0]])

        # Set the s(x,y), and λ(x,y) coordinates
        self.ss = self.s_of_xy(self.s_coeffs)
        self.λλ = self.lam_xy(self.lam_coeffs)
        self.emask = self.edge_mask(self.smoothness)

        self.λλ_min = self.λλ.min().detach().item()
        self.λλ_max = self.λλ.max().detach().item()

        # The wavelength vector should not require grad.
        self.λ_src_vector = torch.linspace(
            self.λλ_min,
            self.λλ_max,
            self.n_amps,
            device=self.device,
            requires_grad=False,
            dtype=torch.float64,
        )

        if dense_sky:
            # For dense sky sampling (no wavelength calibration)
            self.n_sky = self.n_amps
            self.λ_sky_vector = self.λ_src_vector
            self.sky_amps = nn.Parameter(
                torch.full((self.n_amps,), 5.2, dtype=torch.float64, device=device)
            )
            self.sky_model_function = self.dense_sky_model
        else:
            self.λ_sky_vector, peaks = self.get_skyline_wavelengths(fn=skyline_file)
            self.n_sky = len(self.λ_sky_vector)
            # These will be treated as natural logs
            self.sky_amps = nn.Parameter(torch.log(peaks) + 2.3)
            # Sampled in log
            self.sky_continuum_coeffs = as_param([5.7, 0.0, 0.0, 0.0])

            # The grad needlessly adds memory to the backpropagation for this
            # smooth function
            with torch.no_grad():
                self.λn = (
                    2 * (self.λλ - self.λλ.mean()) / (self.λλ.max() - self.λλ.min())
                )
                self.cheb_array = torch.stack(
                    [
                        torch.ones_like(self.xx, device=device),
                        self.λn,
                        2 * self.λn ** 2 - 1,
                        4 * self.λn ** 3 - 3 * self.λn,
                    ]
                ).to(device)
            self.sky_model_function = self.sparse_sky_model

    def forward(self, index):
        """The forward pass of the neural network model

        Args:
            index (int): the index of the nod position whose row of
                ``p_coeffs`` describes the source profile, *e.g.* A=0, B=1.
                ``p_coeffs`` is initialized with two rows, so the valid
                indices are 0 and 1.

        Returns:
            (torch.tensor): the 2D generative scene model destined for
            backpropagation parameter tuning
        """
        return self.generative_model(index)

    def s_of_xy(self, params):
        """The along-slit coordinate :math:`s` as a function of :math:`(x,y)`, given coefficients

        Args:
            params (torch.tensor or tuple): four weights, in order: the
                :math:`y` zero point, the plate scale, the slope of the zero
                point with :math:`x`, and the change of plate scale with
                :math:`x` (applied per 100 pixels)

        Returns:
            (torch.tensor): the 2D surface map :math:`s(x,y)`
        """
        y_zero, scale, dy_zero_dx, scale_x = params
        plate_scale = scale + scale_x / 100 * self.xx
        offset = (self.yy - y_zero) - dy_zero_dx * self.xx
        return plate_scale * offset

    def edge_mask(self, log_smoothness, slit_length=12.0):
        r"""The soft-edge pixel mask defined by the extent of the spectrograph slit length

        Constructed by the product of two sigmoid functions to make a smooth
        tophat over :math:`0 < s < L`:

        .. math::

            m_e = \mathscr{S}\left(\frac{s}{\beta}\right) \cdot
                  \mathscr{S}\left(\frac{L - s}{\beta}\right)

        with :math:`\beta = \exp(\mathtt{log\_smoothness})`.  Reads the current
        ``self.ss`` map, so it must be called after ``self.ss`` is set.

        Args:
            log_smoothness (torch.tensor or tuple): the natural log of the
                :math:`\beta` smoothness parameter related to image quality
            slit_length (float): the slit length :math:`L`, in the units of
                ``self.ss``.  Default: 12.0

        Returns:
            (torch.tensor): the 2D surface map :math:`m_e(x,y)`
        """
        beta = torch.exp(log_smoothness)
        bottom_edge = torch.sigmoid((self.ss - 0.0) / beta)
        top_edge = torch.sigmoid((slit_length - self.ss) / beta)
        return bottom_edge * top_edge

    def lam_xy(self, c):
        r"""A 2D Surface mapping :math:`(x,y)` pixels to :math:`\lambda`

        Each (x,y) pixel coordinate maps to a single central wavelength.  This
        function performs that transformation, given the coefficients of
        polynomials, the `x` and `y` values, and a fiducial central wavelength
        and dispersion.  The coefficients in this function are intended to be
        fit through iterative stochastic gradient descent.

        Args:
            c (torch.tensor): polynomial weights and bias fitted through
                backpropagation; four entries are used (``c[0]`` to ``c[3]``)

        Returns:
            (torch.tensor): the 2D surface map :math:`\lambda(x,y)`
        """
        # Pixel coordinates rescaled about the detector/order midpoint.
        x = (self.xx - self.nx / 2) / (self.nx / 2)
        y = (self.yy - self.ny / 2) / (self.ny / 2)

        const = self.fiducial[0]
        c0 = c[0]  # Shift: Angstroms ~[-3, 3]
        cx1 = (
            self.fiducial[1] * (1 + c[1] * 0.01) * self.nx / 2
        )  # Dispersion adjustment: Dimensionless ~[-1, 1]
        cx2 = 1.0 + c[2]  # Pixel-dependent dispersion: Angstroms [-1.5, 1.5]
        cy1 = c[3]  # Vertically-Tilted straight arclines [Angstroms/pixel] [-0.3, 0.3]

        term0 = c0
        xterm1 = cx1 * x
        xterm2 = cx2 * (2 * x ** 2 - 1)
        yterm1 = cy1 * y

        return const + (term0 + xterm1 + xterm2) + yterm1

    def single_arcline(self, amp, lam_0, lam_sigma):
        """Evaluate a normalized arcline given a 2D wavelength map

        Args:
            amp: amplitude multiplying the normalized Gaussian
            lam_0: central wavelength of the line
            lam_sigma: Gaussian width of the line

        Returns:
            (torch.tensor): the Gaussian line evaluated on ``self.λλ``
        """
        ln_prob = Normal(loc=lam_0, scale=lam_sigma).log_prob(self.λλ)
        return amp * torch.exp(ln_prob)

    def native_pixel_model(self, amp_of_lambda, lam_vec):
        """A Native-pixel model of the scene

        Every wavelength in ``lam_vec`` contributes a normalized Gaussian of
        fixed width 0.42 (in the units of the wavelength map) weighted by the
        matching entry of ``amp_of_lambda``; the contributions are summed.

        Args:
            amp_of_lambda (torch.tensor): amplitude per wavelength sample
            lam_vec (torch.tensor): the wavelength samples

        Returns:
            (torch.tensor): the 2D scene, with the shape of ``self.λλ``
        """
        # Broadcast (nx, ny, 1) pixels against (1, 1, n_lines) line centers.
        line_centers = lam_vec.unsqueeze(0).unsqueeze(0)
        log_scene_cube = Normal(loc=line_centers, scale=0.42).log_prob(
            self.λλ.unsqueeze(2)
        )
        weights = amp_of_lambda.unsqueeze(0).unsqueeze(0)
        return (weights * torch.exp(log_scene_cube)).sum(dim=2)

    def dense_sky_model(self):
        """A sky model with dense (~1400) spectral lines"""
        return self.native_pixel_model(torch.exp(self.sky_amps), self.λ_sky_vector)

    def sparse_sky_model(self):
        """A sky model with a few (~3-10) spectral lines plus a smooth continuum"""
        sky_lines = self.native_pixel_model(
            torch.exp(self.sky_amps), self.λ_sky_vector
        )
        return sky_lines + self.sky_continuum_model()

    def sky_continuum_model(self):
        """A smooth model for the background sky emission in sparse-sky models

        The log of the continuum is a weighted sum of the Chebyshev terms
        stored in ``self.cheb_array``.

        Returns:
            (torch.tensor): the 2D sky emission continuum
        """
        weights = self.sky_continuum_coeffs.unsqueeze(1).unsqueeze(2)
        log_sky_cont = (self.cheb_array * weights).sum(0)
        return torch.exp(log_sky_cont)

    def source_profile_simple(self, p_coeffs):
        """The profile of the source along the slit, given position and width coefficients

        Args:
            p_coeffs (torch.tensor): ``p_coeffs[0]`` is the position in
                arcseconds (0, 12); ``p_coeffs[1]`` is the natural log of the
                Gaussian width

        Returns:
            (torch.tensor): the normalized Gaussian evaluated on ``self.ss``
        """
        sigma = torch.exp(p_coeffs[1])
        ln_prob = Normal(loc=p_coeffs[0], scale=sigma).log_prob(self.ss)
        return torch.exp(ln_prob)

    def source_profile_medium(self, p_coeffs):
        """The profile of the source along the slit, given position, width, trend coefficients

        Args:
            p_coeffs (torch.tensor): ``p_coeffs[0]`` is the position in
                arcseconds (0, 12); ``p_coeffs[1]`` is the natural log of the
                Gaussian width; ``p_coeffs[2]`` and ``p_coeffs[3]`` weight the
                first- and second-order Chebyshev terms of the normalized
                :math:`x` coordinate (``self.cheb_x``), letting the position
                drift along the order

        Returns:
            (torch.tensor): the normalized Gaussian evaluated on ``self.ss``
        """
        sigma = torch.exp(p_coeffs[1])
        trend_weights = p_coeffs[[0, 2, 3]]
        # Position of the trace as a function of x, one value per column.
        loc_vector = (trend_weights.unsqueeze(1) * self.cheb_x).sum(0)
        ln_prob = Normal(loc=loc_vector.unsqueeze(1), scale=sigma).log_prob(self.ss)
        return torch.exp(ln_prob)

    def generative_model(self, index):
        """The generative model resembles echelle spectra traces in astronomy data

        Refreshes ``self.ss``, ``self.λλ`` and ``self.emask`` from the current
        parameters, then adds the slit-masked sky, the source, and a constant
        background.

        Args:
            index (int): row of ``p_coeffs`` used for the source profile

        Returns:
            (torch.tensor): the 2D scene model
        """
        # Coordinate maps depend on trainable parameters: rebuild every pass.
        self.ss = self.s_of_xy(self.s_coeffs)
        self.λλ = self.lam_xy(self.lam_coeffs)
        self.emask = self.edge_mask(self.smoothness)

        sky_model = self.sky_model_function()
        src_model = self.native_pixel_model(
            torch.exp(self.src_amps), self.λ_src_vector
        )
        src_prof = self.source_profile_medium(self.p_coeffs[index].squeeze())

        net_sky = self.emask * sky_model
        net_src = src_prof * src_model
        return net_sky + net_src + torch.exp(self.bkg_const)

    def get_skyline_wavelengths(self, fn=None):
        r"""Get the wavelengths of bright sky lines (e.g. OH)

        The line list is a whitespace-delimited table of wavelength and
        relative flux.  It is read from ``fn``; if that cannot be opened or
        parsed, a pinned copy is downloaded instead.  Only lines strictly
        inside ``(self.λλ_min, self.λλ_max)`` are kept.

        Args:
            fn (str or None): path or URL of the line list.  ``None`` uses the
                original hard-coded local path.  Default: None

        Returns:
            (tuple of torch.tensor): the line wavelengths and their relative
            fluxes, as float64 tensors on ``self.device``
        """
        if fn is None:
            fn = "/home/gully/GitHub/ynot/data/ir_ohlines.dat"
        fn_web = 'https://raw.githubusercontent.com/Keck-DataReductionPipelines/NIRSPEC-Data-Reduction-Pipeline/4bf6db52771bdd7a3f4ec80a3418abcc92b2cf43/ir_ohlines.dat'
        # `sep=r"\s+"` replaces the deprecated `delim_whitespace=True`.
        read_kwargs = {"names": ["wl", "rel_flux"], "sep": r"\s+"}

        try:
            df = pd.read_csv(fn, **read_kwargs)
        except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
            # Local copy missing or unreadable: fall back to the pinned web
            # copy.  Unlike a bare `except`, this leaves KeyboardInterrupt and
            # programming errors visible.
            df = pd.read_csv(fn_web, **read_kwargs)

        in_range = (df.wl > self.λλ_min) & (df.wl < self.λλ_max)
        df = df[in_range]

        wls = torch.tensor(df.wl.values, device=self.device, dtype=torch.float64)
        peaks = torch.tensor(
            df.rel_flux.values, device=self.device, dtype=torch.float64
        )
        return (wls, peaks)