# Source code for acm.observables.base

import torch
import xarray
import numpy as np
from pathlib import Path 
from sunbird.emulators import FCN
from sunbird.data.data_utils import transform_filters_to_slices
from acm.utils.xarray_data import dataset_from_dict
import logging

[docs] class Observable(): """ Class to load a compressed Observable file or model and apply filters to their outputs. """ def __init__( self, stat_name: str, paths: dict = None, select_filters: dict = None, slice_filters: dict = None, select_indices: list = None, select_indices_on: list = ['y', 'covariance_y', 'emulator_error', 'emulator_covariance_y'], flat_output_dims: int = None, squeeze_output: bool = False, numpy_output: bool = False, ): """ Parameters ---------- stat_name: str Name of the statistic to load. Also the name of the file containing the data. paths: dict, optional Paths to the compressed Observable files or models. If None, the internal dataset will be None. Defaults to None. select_filters : dict, optional Filters to select values in coordinates. Defaults to None. slice_filters : dict, optional Filters to slice values in coordinates. Defaults to None. select_indices : list, optional Indices to select in the flattened data vector. Cannot be used with `select_filters` or `slice_filters`. Defaults to None. select_indices_on : list, optional List of data variables to apply the indices selection on. Defaults to ['y', 'covariance_y', 'emulator_error', 'emulator_covariance_y']. flat_output_dims : int, optional If 2, the output will be flattened on two dimensions (sample and features). If 1, the output will be flattened on a single dimension (dims) - Not recommended. If None, the output will not be flattened. Defaults to None. squeeze_output : bool, optional If True, the output will be squeezed to remove single-dimensional entries. Defaults to False. numpy_output : bool, optional If True, the output will be converted to a numpy array. Defaults to False. Paths ----- The data is expected to be in paths[key]/stat_name.npy, in which an xarray DataSet is stored. 
The possible keys are: - 'data_dir': directory containing the data (x, y) - 'covariance_dir': directory containing the covariance of the data (covariance_y) - 'error_dir': directory containing the emulator error of the data (emulator_error, emulator_covariance_y) - 'model_dir': directory containing the trained model (model.pth) - 'checkpoint_name': name of the checkpoint file (default: 'model.pth') Example ------- :: slice_filters = {'sep': (0, 0.5),} select_filters = {'multipoles': [0, 2],} will return the summary statistics for `0 < sep < 0.5` and multipoles 0 and 2 """ self.logger = logging.getLogger(self.__class__.__name__) self.stat_name = stat_name self.paths = paths if paths else {} # Ensure dict behavior self.flat_output_dims = flat_output_dims self.numpy_output = numpy_output self.squeeze_output = squeeze_output # Try to read the paths with data inside datasets = [] for key in ['data_dir', 'covariance_dir', 'error_dir']: if key not in self.paths: continue path = Path(self.paths[key]) / f"{self.stat_name}.npy" if path.exists(): datasets.append(dataset_from_dict(np.load(path, allow_pickle=True).item())) self.logger.info(f"Loaded {key} from {path}") if len(datasets) == 0: self.logger.warning("No datasets found within provided paths.") self._dataset = None else: self._dataset = xarray.merge(datasets) self.logger.info("Datasets loaded with the following coordinates: {}".format(list(self._dataset.data_vars.keys()))) # Load the model try: self.model = self.load_model() except Exception as e: # handle the case where the model checkpoint is not found self.logger.warning(f"{e}, model will be undefined. 
If you are training a new model, this is expected.") # Set the filters self.select_filters = select_filters self.slice_filters = slice_filters self.select_indices = select_indices self.select_indices_on = select_indices_on if select_indices_on else [] # Ensure list behavior if self.select_indices and (self.select_filters or self.slice_filters): self.logger.warning("Indice selection and filters used at the same time. Check what you filter, you might not get the result you expect!") def __str__(self): """ Returns a string representation of the object (statistic names and slice filters). """ # TODO : improve this and __repr__ later ? return self.get_save_handle() def __getattr__(self, name): """ Returns the attribute of the class xarray _dataset, with the filter applied. Also reshapes the output by stacking coordinates (on one or two dims) if flat_output_dims is set. """ # First, apply the filters dataset = self._dataset dataset = self.apply_filters(dataset) data = getattr(dataset, name) # Apply reshaping if name is a data_var if name in self._dataset.data_vars: if name in self.select_indices_on: data = self.apply_indices_selection(data) data = self.flatten_output(data) if self.squeeze_output: data = data.squeeze() if self.numpy_output: data = data.values return data
[docs] @staticmethod def stack_on_attribute(attribute: str|dict, dataarray: xarray.DataArray, **kwargs) -> xarray.DataArray: """ Stacks a DataArray on the dimensions given. Parameters ---------- attribute: str | Mapping The dimension(s) to stack on. If a string, will be read from the DataArray attributes. Will be used as the dim to stack on (see xarray.DataArray.stack) dataarray : xarray.DataArray The DataArray to stack the dimensions on. **kwargs Additional keyword arguments to pass to the stack method. Returns ------- xarray.DataArray The stacked DataArray """ if isinstance(attribute, str): if not attribute in dataarray.attrs or attribute in dataarray.dims: return dataarray attribute_list = [i for i in dataarray.attrs[attribute] if i in dataarray.dims] dim = {attribute: attribute_list} else: dim = attribute dim_name = list(dim.keys())[0] if len(dim[dim_name]) != 0: da = dataarray.stack(**dim, **kwargs) else: da = dataarray.expand_dims(dim_name) return da
[docs] def apply_filters(self, dataarray: xarray.DataArray) -> xarray.DataArray: """ Apply the class filters on a given DataArray or Dataset. Parameters ---------- dataarray : xarray.DataArray The DataArray to apply the filters on. Returns ------- xarray.DataArray The filtered DataArray. """ dimensions = dataarray.dims select_filters = {k: v for k, v in self.select_filters.items() if k in dimensions} if self.select_filters else None slice_filters = {k: v for k, v in self.slice_filters.items() if k in dimensions} if self.slice_filters else None if select_filters: dataarray = dataarray.sel(**select_filters) if slice_filters: slice_filters = transform_filters_to_slices(slice_filters) dataarray = dataarray.sel(**slice_filters) return dataarray
[docs] def flatten_output(self, dataarray: xarray.DataArray) -> xarray.DataArray: """ Flatten the output of a given DataArray by stacking all dimensions over attributes 'sample' and 'features', containing the list of dimensions to stack on. If flat_output_dims is 2, stacks on both 'sample' and 'features' attributes. If flat_output_dims is 1, stacks all dimensions into a single dimension 'dims'. Otherwise, returns the DataArray as is. Parameters ---------- dataarray : xarray.DataArray The DataArray to flatten. Returns ------- xarray.DataArray The flattened DataArray. """ dataarray = dataarray.unstack() if self.flat_output_dims == 2: dataarray = self.stack_on_attribute('sample', dataarray) dataarray = self.stack_on_attribute('features', dataarray) dataarray = dataarray.transpose('sample', 'features') elif self.flat_output_dims == 1: dataarray = dataarray.stack(dims=[...]) return dataarray
[docs] def apply_indices_selection(self, dataarray: xarray.DataArray) -> xarray.DataArray: """ Apply the indices selection on a given DataArray. Should be called after filters are applied and before flattening. Does nothing if select_indices is None. Parameters ---------- dataarray : xarray.DataArray The DataArray to apply the indices selection on. Returns ------- xarray.DataArray The DataArray with the selected indices. """ if self.select_indices is None: return dataarray dataarray = self.stack_on_attribute('features', dataarray) # Warn if filters are applied to features dimensions if self.select_filters: features_filters = [k for k in dataarray.attrs['features'] if k in self.select_filters.keys()] if any(features_filters): self.logger.warning("Filters applied to features dimensions: {}".format(features_filters)) if self.slice_filters: features_filters = [k for k in dataarray.attrs['features'] if k in self.slice_filters.keys()] if any(features_filters): self.logger.warning("Filters applied to features dimensions: {}".format(features_filters)) return dataarray.isel(features=self.select_indices)
[docs] def get_coordinate_list(self, name: str) -> list: """ Returns the list of values of a coordinate of the dataset Parameters ---------- name : str The name of the coordinate to retrieve. Returns ------- list The list of values of the specified coordinate. """ coordinate_list = self.coords[name].values.tolist() if not isinstance(coordinate_list, list): coordinate_list = [coordinate_list] return coordinate_list
@property def x_names(self) -> list: """ Returns the list of the parameters coordinate of the x dataset. Returns ------- list The list of the parameters of the x dataset. """ return self.get_coordinate_list('parameters') @property def emulator_error(self): """ Returns the emulator error of the statistic, with filters applied. Reads the emulator error from the error_dir if it is provided, otherwise uses the get_emulator_error method if implemented. """ if hasattr(self._dataset, 'emulator_error'): data = self._dataset.emulator_error data = self.apply_filters(data) if 'emulator_error' in self.select_indices_on: data = self.apply_indices_selection(data) data = self.flatten_output(data) if self.squeeze_output: data = data.squeeze() if self.numpy_output: data = data.values return data elif hasattr(self, 'get_emulator_error'): return self.get_emulator_error() else: raise NotImplementedError("No emulator error found. Please provide an error_dir or implement the get_emulator_error method.") @property def emulator_covariance_y(self): """ Returns the covariance of the emulator error of the statistic, with filters applied. Reads the emulator covariance from the error_dir if it is provided, otherwise uses the get_emulator_covariance_y method if implemented. """ if hasattr(self._dataset, 'emulator_covariance_y'): data = self._dataset.emulator_covariance_y data = self.apply_filters(data) if 'emulator_covariance_y' in self.select_indices_on: data = self.apply_indices_selection(data) data = self.flatten_output(data) if self.squeeze_output: data = data.squeeze() if self.numpy_output: data = data.values return data elif hasattr(self, 'get_emulator_covariance_y'): return self.get_emulator_covariance_y() else: raise NotImplementedError("No emulator covariance found. Please provide an error_dir or implement the get_emulator_covariance_y method.") @property def checkpoint_fn(self) -> str: """ Path to the checkpoint file of the model, constructed from the paths and the statistic name. 
""" return self.paths['model_dir'] + f'{self.stat_name}/' + self.paths['checkpoint_name'] # FIXME : Update this format later
[docs] def load_model(self, checkpoint_fn: str = None) -> FCN: """ Trained theory model. """ if checkpoint_fn is None: checkpoint_fn = self.checkpoint_fn # Register safe globals for transform classes to allow loading checkpoints # with PyTorch 2.6+ (which changed weights_only default to True) safe_classes = [] transforms_array_imported = False # Try to import transform classes from sunbird.data.transforms_array try: from sunbird.data.transforms_array import ( LogTransform, ArcsinhTransform, ) safe_classes.extend([LogTransform, ArcsinhTransform]) transforms_array_imported = True except ImportError: pass # Register the classes as safe globals if torch.serialization.add_safe_globals exists if safe_classes: try: torch.serialization.add_safe_globals(safe_classes) except AttributeError: # torch.serialization.add_safe_globals doesn't exist in older PyTorch versions self.logger.debug("torch.serialization.add_safe_globals not available, skipping safe globals registration") # Load the model model = FCN.load_from_checkpoint(checkpoint_fn, strict=True) model.eval().to('cpu') # Set transforms for minkowski models if self.stat_name.startswith('minkowski'): if not transforms_array_imported: # Import if not already done (e.g., if initial import failed) from sunbird.data.transforms_array import WeiLiuInputTransform, WeiLiuOutputTransForm model.transform_output = WeiLiuOutputTransForm() model.transform_input = WeiLiuInputTransform() return model
[docs] def get_model_prediction(self, x, model=None, coords=None, attrs=None, nofilters: bool = False) -> xarray.DataArray: """ Get the prediction from the model. Parameters ---------- x : array_like, dict Input features. model : FCN Trained theory model. If None, the model attribute of the class is used. Defaults to None. coords : dict, optional Coordinates for the output DataArray. If None, the coordinates of _dataset.y are used. Defaults to None. attrs : dict, optional Attributes for the output DataArray. If None, the attributes of _dataset.y are used. Defaults to None. nofilters : bool, optional If True, no filters are applied to the output and the full DataArray is returned. Defaults to False. Returns ------- array_like Model prediction. """ if isinstance(x, dict): missing = set(self.x_names) - set(x.keys()) extra = set(x.keys()) - set(self.x_names) if missing: raise ValueError( "Input x dictionary keys do not match the model input names. " f"Missing keys: {missing}" ) if extra: logger.warning( "Input x dictionary contains unexpected keys not used by the model. 
" f"Unexpected keys: {extra}" ) x = [x[name] for name in self.x_names] x = np.asarray(x).T # Need to transpose to (n_samples, n_features) else: x = np.asarray(x) # Ensure x is an array to make torch.Tensor faster if model is None: model = self.model with torch.no_grad(): pred = model.get_prediction(torch.Tensor(x)) pred = pred.numpy() if coords is None: coords = { **{k: self._dataset.y.coords[k] for k in self._dataset.y.dims if k in self._dataset.y.attrs['features']} } if attrs is None: attrs = { 'sample': ['n_pred'], 'features': self._dataset.y.attrs['features'], } n_pred = pred.shape[0] if len(pred.shape) > 1 else 1 # Edge case if only one prediction coords = {**{'n_pred': np.arange(n_pred)}, **coords} # Add extra coordinate for the number of predictions pred = pred.reshape([len(c) for c in coords.values()]) # reshape to the right shape pred = xarray.DataArray( pred, coords = coords, attrs = attrs, ) if nofilters: return pred pred = self.apply_filters(pred) pred = self.apply_indices_selection(pred) pred = self.flatten_output(pred) if self.squeeze_output: pred = pred.squeeze() if self.numpy_output: pred = pred.values return pred
[docs] def get_covariance_matrix(self, volume_factor: float = 64, prefactor: float = 1) -> np.ndarray: """ Covariance matrix for the statistic. The prefactor is here for corrections if needed, and the volume factor is the volume correction of the boxes. """ cov_y = self.covariance_y # Ensure 2D shape of the covariance array if isinstance(cov_y, xarray.DataArray): cov_y = cov_y.unstack() cov_y = self.stack_on_attribute('sample', cov_y) cov_y = self.stack_on_attribute('features', cov_y) cov_y = cov_y.transpose('sample', 'features') elif len(cov_y.shape) > 2: self.logger.warning("Covariance array has more than 2 dimensions, reshaping to 2D assuming first dimension is the sample dimension.") cov_y = cov_y.reshape(cov_y.shape[0], -1) # Expect first dimension to be the sample dimension elif len(cov_y.shape) < 2: self.logger.error("Covariance array has less than 2 dimensions, covariance matrix computation might return some unexpected results.") prefactor = prefactor / volume_factor cov = prefactor * np.cov(cov_y, rowvar=False) # rowvar=False : each column is a variable and each row is an observation return cov
[docs] def get_emulator_covariance_matrix(self, prefactor: float = 1) -> np.ndarray: """ Emulator covariance matrix for the statistic. The prefactor is here for corrections if needed. """ cov_y = self.emulator_covariance_y prefactor = prefactor # Ensure 2D shape of the covariance array if isinstance(cov_y, xarray.DataArray): cov_y = cov_y.unstack() cov_y = self.stack_on_attribute('sample', cov_y) cov_y = self.stack_on_attribute('features', cov_y) cov_y = cov_y.transpose('sample', 'features') elif len(cov_y.shape) > 2: self.logger.warning("Covariance array has more than 2 dimensions, reshaping to 2D assuming first dimension is the sample dimension.") cov_y = cov_y.reshape(cov_y.shape[0], -1) # Expect first dimension to be the sample dimension elif len(cov_y.shape) < 2: self.logger.error("Covariance array has less than 2 dimensions, covariance matrix computation might return some unexpected results.") cov = prefactor * np.cov(cov_y, rowvar=False) return cov
[docs] def get_save_handle(self, save_dir: str|Path = None) -> str|Path: """ Creates a handle that includes the statistics and filters used. This can be used to save anything related to this observable. Parameters ---------- save_dir : str Directory where the results will be saved. If provided, the directory is created if it does not exist. If None, the handle is returned as a string. Default is None. Returns ------- str|Path The handle for saving the results, to be completed with the file extension. Returned as a Path instance if save_dir is provided as a Path. """ slice_filters = self.slice_filters statistic_handle = self.stat_name if slice_filters: for key, value in slice_filters.items(): statistic_handle += f'_{key}_{value[0]:.2f}-{value[1]:.2f}' # TODO : add select filters to the handle ? if save_dir is None: return statistic_handle # If save_path is provided, make sure it exists Path(save_dir).mkdir(parents=True, exist_ok=True) cout = Path(save_dir) / f'{statistic_handle}' if isinstance(save_dir, str): return cout.as_posix() # Return as string if save_dir is a string return cout