Source code for acm.observables.combined

import numpy as np
from pathlib import Path
from .base import Observable
import logging


class CombinedModel:
    """
    Combination of the theory models of several observables.

    Each wrapped observable keeps its own filters; predictions are produced by
    the observables themselves and joined along the last axis, in the order in
    which the observables were given.
    """

    def __init__(self, observables: "list[Observable]"):
        """
        Parameters
        ----------
        observables : list[Observable]
            Observables to combine, each already initialized with its filters.
        """
        self.observables = observables
        # One theory model per observable, kept in the same order.
        models = []
        for observable in self.observables:
            models.append(observable.model)
        self.models = models

    def get_prediction(self, x):
        """
        Predict every observable for the given inputs and join the results.

        Parameters
        ----------
        x : array_like
            Input features, passed unchanged to each observable's
            ``get_model_prediction``.

        Returns
        -------
        array_like
            The per-observable predictions (each with that observable's filters
            applied) concatenated along the last axis.
        """
        per_observable = [
            observable.get_model_prediction(x) for observable in self.observables
        ]
        return np.concatenate(per_observable, axis=-1)
[docs] class CombinedObservable(): """ Class for the combination of observables. It has list properties, that allow to access easily self.observables for readibility. """ def __init__(self, observables: list[Observable]): """ Parameters ---------- observables : list[Observable] List of observables to be combined, initialized with their respective filters. """ self.logger = logging.getLogger(self.__class__.__name__) self.observables = observables self.slice_filters = [obs.slice_filters for obs in self.observables] self.select_filters = [obs.select_filters for obs in self.observables] is_reshaped = [obs.flat_output_dims == 2 for obs in self.observables] if not all(is_reshaped): self.logger.warning("Not all observables have flat_output_dims=2. Some outputs might not be properly reshaped, which might cause concatenation issues.") def __str__(self): """ Returns a string representation of the object (statistic names and slice filters). """ return self.get_save_handle() def __getitem__(self, item): """ Allows to access the observables by their statistic name or their index. """ if isinstance(item, int): return self.observables[item] elif isinstance(item, str): try: idx = self.stat_name.index(item) return self.observables[idx] except ValueError: # If the item is not found in the list, raise an error KeyError(f"Observable with name {item} not found.") else: raise TypeError(f"Item must be an int or str, not {type(item)}.") def __setitem__(self, item, value): """ Allows to set the observable at the given index. """ if not isinstance(value, Observable): raise TypeError(f"Value must be a Observable, not {type(value)}.") if not isinstance(item, int): raise TypeError(f"Item must be an int, not {type(item)}.") self.observables[item] = value def __len__(self): """ Returns the number of observables in the combination. """ return len(self.observables) def __iter__(self): """ Returns an iterator over the observables in the combination. 
""" return iter(self.observables) def __contains__(self, item): """ Checks if the observable with the given statistic name is in the combination. """ return item in self.stat_name def __reversed__(self): """ Returns a reversed iterator over the observables in the combination. """ return reversed(self.observables) def __add__(self, other): """ Allows to add two CombinedObservable objects together or to add a new observable to the existing Observable. """ if isinstance(other, CombinedObservable): return CombinedObservable(self.observables + other.observables) elif isinstance(other, Observable): return CombinedObservable(self.observables + [other]) else: raise TypeError(f"Cannot add {type(other)} to CombinedObservable.") def __getattr__(self, name): """ Allows to access the observables by their statistic name as an attribute, or the concatenated output of their attributes. """ if name in self.stat_name: idx = self.stat_name.index(name) return self.observables[idx] else: try: return np.concatenate([getattr(obs, name) for obs in self.observables], axis=-1) except AttributeError: raise AttributeError(f"'CombinedObservable' object has no attribute '{name}'") @property def stat_name(self) -> list: """ Name of the statistic. """ return [obs.stat_name for obs in self.observables] @property def x(self) -> np.ndarray: """ Input features (samples). Note: We assume all observable have the same input features, so we just return the first from the list. """ return [obs.x for obs in self.observables][0] @property def x_names(self) -> list: """ Names of the input features. Note: We assume all observable have the same input features, so we just return the first from the list. """ return [obs.x_names for obs in self.observables][0] @property def model(self): """ Theory model of the combination of observables. `model.get_prediction(x)` returns the prediction of the combination of observables, with the respective filters applied to each observable. 
""" return CombinedModel(self.observables)
[docs] def get_model_prediction(self, x)-> np.ndarray: """ Get the prediction from the model. Parameters ---------- x : array_like Input features. Returns ------- array_like Model prediction. """ return np.concatenate([obs.get_model_prediction(x) for obs in self.observables], axis=-1)
[docs] def get_covariance_matrix( self, volume_factor: float = 64, prefactor: float = 1, ) -> np.ndarray: """ Covariance matrix for the statistic. The prefactor is here for corrections if needed, and the volume factor is the volume correction of the boxes. """ cov_y = self.covariance_y prefactor = prefactor / volume_factor cov = prefactor * np.cov(cov_y, rowvar=False) # rowvar=False : each column is a variable and each row is an observation return cov
[docs] def get_emulator_covariance_matrix(self, prefactor: float = 1) -> np.ndarray: """ Emulator covariance matrix for the statistic. The prefactor is here for corrections if needed. """ cov_y = self.emulator_covariance_y prefactor = prefactor cov = prefactor * np.cov(cov_y, rowvar=False) return cov
[docs] def get_save_handle(self, save_dir: str|Path = None) -> str|Path: """ Creates a handle that combines the handles of the observables, separated by a '+'. They contain the statistic name and the filters used. This can be used to save anything related to this observable. Parameters ---------- save_dir : str Directory where the results will be saved. If provided, the directory is created if it does not exist. If None, the handle is returned as a string. Default is None. Returns ------- str|Path The handle for saving the results, to be completed with the file extension. Returned as a Path instance if save_dir is provided as a Path. """ statistic_handles = [ observable.get_save_handle() for observable in self.observables ] statistic_handle = '+'.join(statistic_handles) if save_dir is None: return statistic_handle # If save_path is provided, make sure it exists Path(save_dir).mkdir(parents=True, exist_ok=True) cout = Path(save_dir) / f'{statistic_handle}' if isinstance(save_dir, str): return cout.as_posix() # Return as string if save_dir is a string return Path(save_dir) / f'{statistic_handle}'