Source code for console.spcm_control.acquisition_control

"""Acquisition Control Class."""

import logging
import logging.config
import os
import time
from datetime import datetime
from pathlib import Path

import numpy as np
from scipy import signal

from console.interfaces.acquisition_data import AcquisitionData
from console.interfaces.acquisition_parameter import AcquisitionParameter, DDCMethod
from console.interfaces.dimensions import Dimensions
from console.interfaces.unrolled_sequence import UnrolledSequence
from console.pulseq_interpreter.sequence_provider import Sequence, SequenceProvider
from console.spcm_control.rx_device import RxCard
from console.spcm_control.tx_device import TxCard
from console.utilities import ddc
from console.utilities.load_config import get_instances

LOG_LEVELS = [
    logging.DEBUG,
    logging.INFO,
    logging.WARNING,
    logging.ERROR,
    logging.CRITICAL,
]


[docs] class AcquisitionControl: """Acquisition control class. The main functionality of the acquisition control is to orchestrate transmit and receive cards using ``TxCard`` and ``RxCard`` instances. """ def __init__( self, configuration_file: str, nexus_data_dir: str = os.path.join(Path.home(), "nexus-console"), file_log_level: int = logging.INFO, console_log_level: int = logging.INFO, ): """Construct acquisition control class. Create instances of sequence provider, tx and rx card. Setup the measurement cards and get parameters required for a measurement. Parameters ---------- configuration_file Path to configuration yaml file which is used to create measurement card and sequence provider instances. nexus_data_dir: Nexus console default directory to store logs, states and acquisition data. If none, the default directory is create in the home directory, default is None. file_log_level Set the logging level for log file. Logfile is written to the session folder. console_log_level Set the logging level for the terminal/console output. """ # Create session path (contains all acquisitions of one day) session_folder_name = datetime.now().strftime("%Y-%m-%d") + "-session/" self.session_path = os.path.join(nexus_data_dir, session_folder_name) os.makedirs(self.session_path, exist_ok=True) self._setup_logging(console_level=console_log_level, file_level=file_log_level) self.log = logging.getLogger("AcqCtrl") self.log.info("--- Acquisition control started\n") # Get instances from configuration file ctx = get_instances(configuration_file) self.seq_provider: SequenceProvider = ctx[0] self.tx_card: TxCard = ctx[1] self.rx_card: RxCard = ctx[2] self.seq_provider.output_limits = self.tx_card.max_amplitude # Setup the cards self.is_setup: bool = False if self.tx_card.connect() and self.rx_card.connect(): self.log.info("Setup of measurement cards successful.") self.is_setup = True # Get the rx sampling rate for DDC self.f_spcm = self.rx_card.sample_rate * 1e6 # Set sequence provider max. amplitude per channel according to values from tx_card self.seq_provider.max_amp_per_channel = self.tx_card.max_amplitude self.sequence: UnrolledSequence | None = None # Attributes for data and dwell time of downsampled signal self._raw: list[np.ndarray] = [] self._unproc: list[np.ndarray] = [] def __del__(self): """Class destructor disconnecting measurement cards.""" if self.tx_card: self.tx_card.disconnect() if self.rx_card: self.rx_card.disconnect() self.log.info("Measurement cards disconnected") self.log.info("Acquisition control terminated\n---------------------------------------------------\n") def _setup_logging(self, console_level: int, file_level: int) -> None: # Check if log levels are valid if console_level not in LOG_LEVELS: raise ValueError("Invalid console log level") if file_level not in LOG_LEVELS: raise ValueError("Invalid file log level") # Disable existing loggers logging.config.dictConfig({"version": 1, "disable_existing_loggers": True}) # type: ignore[attr-defined] # Set up logging to file logging.basicConfig( level=file_level, format="%(asctime)s %(name)-7s: %(levelname)-8s >> %(message)s", datefmt="%d-%m-%Y, %H:%M", filename=f"{self.session_path}console.log", filemode="a", ) # Define a Handler which writes INFO messages or higher to the sys.stderr log_console = logging.StreamHandler() log_console.setLevel(console_level) formatter = logging.Formatter("%(name)-7s: %(levelname)-8s >> %(message)s") log_console.setFormatter(formatter) logging.getLogger("").addHandler(log_console)
[docs] def set_sequence(self, sequence: str | Sequence, parameter: AcquisitionParameter) -> None: """Set sequence and acquisition parameter. Parameters ---------- sequence Path to pulseq sequence file. parameter Set of acquisition parameters which are required for the acquisition. Raises ------ AttributeError Invalid sequence provided. FileNotFoundError Invalid file ending of sequence file. """ try: # Check sequence if isinstance(sequence, Sequence): self.seq_provider.from_pypulseq(sequence) elif isinstance(sequence, str): if not sequence.endswith(".seq"): raise FileNotFoundError("Invalid sequence file.") self.seq_provider.read(sequence) except (FileNotFoundError, AttributeError) as err: self.log.exception(err, exc_info=True) raise err # Reset unrolled sequence self.sequence = None self.log.info( "Unrolling sequence: %s", self.seq_provider.definitions["Name"].replace(" ", "_"), ) # Update sequence parameter hash and calculate sequence self.sequence = self.seq_provider.unroll_sequence(parameter=parameter) self.log.info("Sequence duration: %s s", self.sequence.duration)
[docs] def run(self) -> AcquisitionData: """Run an acquisition job. Raises ------ RuntimeError The measurement cards are not setup properly ValueError Missing raw data or missing averages """ try: # Check setup if not self.is_setup: raise RuntimeError("Measurement cards are not setup.") if self.sequence is None: raise ValueError("No sequence set, call set_sequence() to set a sequence and acquisition parameter.") except (RuntimeError, ValueError) as err: self.log.exception(err, exc_info=True) raise err # Define timeout for acquisition process: 5 sec + sequence duration timeout = 5 + self.sequence.duration self._unproc = [] self._raw = [] # Set gradient offset values self.tx_card.set_gradient_offsets( self.sequence.parameter.gradient_offset, self.seq_provider.high_impedance[1:] ) for k in range(self.sequence.parameter.num_averages): self.log.info("Acquisition %s/%s", k + 1, self.sequence.parameter.num_averages) # Start masurement card operations self.rx_card.start_operation() time.sleep(0.01) self.tx_card.start_operation(self.sequence) # Get start time of acquisition time_start = time.time() while (num_gates := len(self.rx_card.rx_data)) < self.sequence.adc_count or num_gates == 0: # Delay poll by 10 ms time.sleep(0.01) if (time.time() - time_start) > timeout: # Could not receive all the data before timeout self.log.warning( "Acquisition Timeout: Only received %s/%s adc events", num_gates, self.sequence.adc_count ) break if num_gates >= self.sequence.adc_count and num_gates > 0: break if num_gates > 0: self.post_processing(self.sequence.parameter) self.tx_card.stop_operation() self.rx_card.stop_operation() if self.sequence.parameter.averaging_delay > 0: time.sleep(self.sequence.parameter.averaging_delay) # Reset gradient offset values self.tx_card.set_gradient_offsets(Dimensions(x=0, y=0, z=0), self.seq_provider.high_impedance[1:]) try: # if len(self._raw) != parameter.num_averages: if not all(gate.shape[0] == self.sequence.parameter.num_averages for gate in self._raw): raise ValueError( "Missing averages: %s/%s", [gate.shape[0] for gate in self._raw], self.sequence.parameter.num_averages, ) except ValueError as err: self.log.exception(err, exc_info=True) raise err return AcquisitionData( _raw=self._raw, unprocessed_data=self._unproc, sequence=self.seq_provider, session_path=self.session_path, meta={ self.tx_card.__name__: self.tx_card.dict(), self.rx_card.__name__: self.rx_card.dict(), self.seq_provider.__name__: self.seq_provider.dict() }, dwell_time=self.sequence.parameter.decimation / self.f_spcm, acquisition_parameters=self.sequence.parameter, )
[docs] def post_processing(self, parameter: AcquisitionParameter) -> None: """Proces acquired NMR data. Data is sorted according to readout size which might vary between different reout windows. Unprocessed and raw data are stored in class attributes _raw and _unproc. Both attributes are list, which store numpy arrays of readout data with the same number of readout sample points. Post processing contains the following steps (per readout sample size): (1) Extraction of reference signal and scaling to float values [mV] (2) Concatenate reference data and signal data in coil dimensions (3) Demodulation along readout dimensions (4) Decimation along readout dimension (5) Phase correction with reference signal Dimensions: [averages, coils, phase encoding, readout] Reference signal is stored in the last entry of the coil dimension. Parameters ---------- parameter Acquisition parameter """ readout_sizes = [data.shape[-1] for data in self.rx_card.rx_data] grouped_gates: dict[int, list] = { readout_sizes[k]: [] for k in sorted(np.unique(readout_sizes, return_index=True)[1]) } for data in self.rx_card.rx_data: grouped_gates[data.shape[-1]].append(data) gate_lengths = [np.stack(group, axis=1) for group in grouped_gates.values()] raw_size = len(self._raw) # Define channel dependent scaling scaling = np.expand_dims(self.rx_card.rx_scaling[:self.rx_card.num_channels.value], axis=(-1, -2)) for k, data in enumerate(gate_lengths): # Extract digital reference signal from channel 0 _ref = (data[0, ...].astype(np.uint16) >> 15).astype(float)[None, ...] # Remove digital signal from channel 0 data[0, ...] = data[0, ...] << 1 data = data.astype(np.int16) * scaling # Stack signal and reference in coil dimension data = np.concatenate((data, _ref), axis=0) # Append unprocessed data without post processing (last coil dimension entry contains reference) if raw_size > 0: self._unproc[k] = np.concatenate((self._unproc[k], data[None, ...]), axis=0) else: self._unproc.append(data[None, ...]) print("Demodulation at freq.:", parameter.larmor_frequency) # Demodulation and decimation data = data * np.exp(2j * np.pi * np.arange(data.shape[-1]) * parameter.larmor_frequency / self.f_spcm) # Always decimate the reference signal with moving average filter ref_dec = ddc.filter_moving_average(data[-1, ...], decimation=parameter.decimation, overlap=8)[None, ...] # Extract the demodulated signal data data = data[:-1, ...] # Switch case for DDC function match parameter.ddc_method: case DDCMethod.CIC: data = ddc.filter_cic_fir_comp(data, decimation=parameter.decimation, number_of_stages=5) case DDCMethod.AVG: data = ddc.filter_moving_average(data, decimation=parameter.decimation, overlap=8) case _: # Default case is FIR decimation data = signal.decimate(data, q=parameter.decimation, ftype="fir") # Apply phase correction with mean value # data = data * np.exp(-1j * np.mean(np.angle(ref_dec), axis = -1))[..., None] data = data * np.exp(-1j * np.angle(ref_dec)) # Append to global raw data list if raw_size > 0: self._raw[k] = np.concatenate((self._raw[k], data[None, ...]), axis=0) else: self._raw.append(data[None, ...])