"""Acquisition Control Class."""
import logging
import logging.config
import os
import time
from datetime import datetime
from pathlib import Path
import numpy as np
from scipy import signal
import console
from console.interfaces.acquisition_data import AcquisitionData
from console.interfaces.acquisition_parameter import AcquisitionParameter, DDCMethod
from console.interfaces.dimensions import Dimensions
from console.interfaces.unrolled_sequence import UnrolledSequence
from console.pulseq_interpreter.sequence_provider import Sequence, SequenceProvider
from console.spcm_control.rx_device import RxCard
from console.spcm_control.tx_device import TxCard
from console.utilities import ddc
from console.utilities.load_config import get_instances
LOG_LEVELS = [
logging.DEBUG,
logging.INFO,
logging.WARNING,
logging.ERROR,
logging.CRITICAL,
]
class AcquisitionControl:
"""Acquisition control class.
The main functionality of the acquisition control is to orchestrate transmit and receive cards using
``TxCard`` and ``RxCard`` instances.
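Examples
--------
A minimal workflow sketch; the configuration and sequence file names are placeholders, not files shipped with this class:
>>> acq = AcquisitionControl(configuration_file="device_config.yaml")
>>> acq.set_sequence("se_spectrum.seq")
>>> acq_data = acq.run()
>>> del acq  # disconnects the measurement cards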
"""
def __init__(
self,
configuration_file: str,
nexus_data_dir: str = os.path.join(Path.home(), "nexus-console"),
file_log_level: int = logging.INFO,
console_log_level: int = logging.INFO,
):
"""Construct acquisition control class.
Create instances of sequence provider, tx and rx card.
Set up the measurement cards and get the parameters required for a measurement.
Parameters
----------
configuration_file
Path to configuration yaml file which is used to create measurement card and sequence
provider instances.
nexus_data_dir
Nexus console default directory to store logs, parameter states and acquisition data.
If not provided, a directory named ``nexus-console`` is created in the user's home directory.
file_log_level
Set the logging level for the log file. The log file is written to the session folder.
console_log_level
Set the logging level for the terminal/console output.
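Examples
--------
Construction sketch, assuming a valid device configuration YAML exists at the (placeholder) path:
>>> acq = AcquisitionControl("device_config.yaml", console_log_level=logging.DEBUG)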
"""
# Create session path (contains all acquisitions of one day)
session_folder_name = datetime.now().strftime("%Y-%m-%d") + "-session/"
self.session_path = os.path.join(nexus_data_dir, session_folder_name)
os.makedirs(self.session_path, exist_ok=True)
self._setup_logging(console_level=console_log_level, file_level=file_log_level)
self.log = logging.getLogger("AcqCtrl")
self.log.info("--- Acquisition control started\n")
# Define global acquisition parameter object
try:
console.parameter = AcquisitionParameter.load(nexus_data_dir)
except FileNotFoundError as exc:
self.log.warning("Acquisition parameter state could not be loaded from dir: %s.\
Creating new acquisition parameter object.", exc)
console.parameter = AcquisitionParameter()
console.parameter.save_on_mutation = True
# Store parameter hash to detect when a sequence needs to be recalculated
self._current_parameter_hash: int = hash(console.parameter)
# Get instances from configuration file
ctx = get_instances(configuration_file)
self.seq_provider: SequenceProvider = ctx[0]
self.tx_card: TxCard = ctx[1]
self.rx_card: RxCard = ctx[2]
self.seq_provider.output_limits = self.tx_card.max_amplitude
# Setup the cards
self.is_setup: bool = False
if self.tx_card.connect() and self.rx_card.connect():
self.log.info("Setup of measurement cards successful.")
self.is_setup = True
# Get the rx sampling rate (converted from MHz to Hz) for DDC
self.f_spcm = self.rx_card.sample_rate * 1e6
# Set sequence provider max. amplitude per channel according to values from tx_card
self.seq_provider.max_amp_per_channel = self.tx_card.max_amplitude
self.unrolled_seq: UnrolledSequence | None = None
# Attributes for raw (processed) and unprocessed acquisition data
self._raw: list[np.ndarray] = []
self._unproc: list[np.ndarray] = []
def __del__(self):
"""Class destructor disconnecting measurement cards."""
if self.tx_card:
self.tx_card.disconnect()
if self.rx_card:
self.rx_card.disconnect()
self.log.info("Measurement cards disconnected")
self.log.info("\n--- Acquisition control terminated\n\n")
def _setup_logging(self, console_level: int, file_level: int) -> None:
# Check if log levels are valid
if console_level not in LOG_LEVELS:
raise ValueError("Invalid console log level")
if file_level not in LOG_LEVELS:
raise ValueError("Invalid file log level")
# Disable existing loggers
logging.config.dictConfig({"version": 1, "disable_existing_loggers": True}) # type: ignore[attr-defined]
# Set up logging to file
logging.basicConfig(
level=file_level,
format="%(asctime)s %(name)-7s: %(levelname)-8s >> %(message)s",
datefmt="%d-%m-%Y, %H:%M",
filename=f"{self.session_path}console.log",
filemode="a",
)
# Define a handler which writes messages at console_level or higher to sys.stderr
console_handler = logging.StreamHandler()
console_handler.setLevel(console_level)
formatter = logging.Formatter("%(name)-7s: %(levelname)-8s >> %(message)s")
console_handler.setFormatter(formatter)
logging.getLogger("").addHandler(console_handler)
def set_sequence(self, sequence: str | Sequence) -> None:
"""Set sequence and acquisition parameter.
Parameters
----------
sequence
Path to pulseq sequence file.
parameter
Set of acquisition parameters which are required for the acquisition.
Raises
------
AttributeError
Invalid sequence provided.
FileNotFoundError
Invalid file ending of sequence file.
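Examples
--------
Sketch, assuming ``acq`` is an ``AcquisitionControl`` instance; the sequence file path is a placeholder:
>>> acq.set_sequence("./sequences/se_spectrum.seq")
A pypulseq ``Sequence`` object can be passed directly as well.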
"""
try:
# Check sequence
if isinstance(sequence, Sequence):
self.seq_provider.from_pypulseq(sequence)
elif isinstance(sequence, str):
if not sequence.endswith(".seq"):
raise FileNotFoundError("Invalid sequence file.")
self.seq_provider.read(sequence)
except (FileNotFoundError, AttributeError) as err:
self.log.exception(err, exc_info=True)
raise err
# Reset unrolled sequence
self.unrolled_seq = None
self.log.info(
"Unrolling sequence: %s",
self.seq_provider.definitions["Name"].replace(" ", "_"),
)
# Update sequence parameter hash and calculate sequence
self._current_parameter_hash = hash(console.parameter)
self.unrolled_seq = self.seq_provider.unroll_sequence()
self.log.info("Sequence duration: %s s", self.unrolled_seq.duration)
def run(self) -> AcquisitionData:
"""Run an acquisition job.
Raises
------
RuntimeError
The measurement cards are not setup properly
ValueError
Missing raw data or missing averages
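Examples
--------
Sketch, assuming ``acq`` is a set-up ``AcquisitionControl`` instance with a sequence already set:
>>> acq_data = acq.run()
>>> acq_data.dwell_time  # decimation factor divided by the spectrum card sampling rate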
"""
try:
# Check setup
if not self.is_setup:
raise RuntimeError("Measurement cards are not setup.")
if self.unrolled_seq is None:
raise ValueError("No sequence set, call set_sequence() to set a sequence and acquisition parameter.")
except (RuntimeError, ValueError) as err:
self.log.exception(err, exc_info=True)
raise err
if self._current_parameter_hash != hash(console.parameter):
# Redo sequence unrolling in case acquisition parameters changed, i.e. different hash
self.unrolled_seq = None
self.log.info(
"Unrolling sequence: %s", self.seq_provider.definitions["Name"].replace(" ", "_")
)
# Update acquisition parameter hash value
self._current_parameter_hash = hash(console.parameter)
self.unrolled_seq = self.seq_provider.unroll_sequence()
self.log.info("Sequence duration: %s s", self.unrolled_seq.duration)
# Define timeout for acquisition process: 5 sec + sequence duration
timeout = 5 + self.unrolled_seq.duration
self._unproc = []
self._raw = []
# Set gradient offset values
self.tx_card.set_gradient_offsets(console.parameter.gradient_offset, self.seq_provider.high_impedance[1:])
for k in range(console.parameter.num_averages):
self.log.info("Acquisition %s/%s", k + 1, console.parameter.num_averages)
# Start measurement card operations
self.rx_card.start_operation()
time.sleep(0.01)
self.tx_card.start_operation(self.unrolled_seq)
# Get start time of acquisition
time_start = time.time()
while (num_gates := len(self.rx_card.rx_data)) < self.unrolled_seq.adc_count:
# Delay poll by 10 ms
time.sleep(0.01)
if (time.time() - time_start) > timeout:
# Could not receive all the data before timeout
self.log.warning(
"Acquisition timeout: only received %s/%s adc events",
num_gates, self.unrolled_seq.adc_count
)
break
if num_gates > 0:
self.post_processing(console.parameter)
self.tx_card.stop_operation()
self.rx_card.stop_operation()
if console.parameter.averaging_delay > 0:
time.sleep(console.parameter.averaging_delay)
# Reset gradient offset values
self.tx_card.set_gradient_offsets(Dimensions(x=0, y=0, z=0), self.seq_provider.high_impedance[1:])
try:
if not all(gate.shape[0] == console.parameter.num_averages for gate in self._raw):
raise ValueError(
f"Missing averages: {[gate.shape[0] for gate in self._raw]}/{console.parameter.num_averages}"
)
except ValueError as err:
self.log.exception(err, exc_info=True)
raise err
return AcquisitionData(
_raw=self._raw,
unprocessed_data=self._unproc,
sequence=self.seq_provider,
session_path=self.session_path,
meta={
self.tx_card.__name__: self.tx_card.dict(),
self.rx_card.__name__: self.rx_card.dict(),
self.seq_provider.__name__: self.seq_provider.dict()
},
dwell_time=console.parameter.decimation / self.f_spcm,
acquisition_parameters=console.parameter,
)
def post_processing(self, parameter: AcquisitionParameter) -> None:
"""Proces acquired NMR data.
Data is sorted according to readout size which might vary between different reout windows.
Unprocessed and raw data are stored in class attributes _raw and _unproc.
Both attributes are list, which store numpy arrays of readout data with the same number
of readout sample points.
Post processing contains the following steps (per readout sample size):
(1) Extraction of reference signal and scaling to float values [mV]
(2) Concatenate reference data and signal data in coil dimensions
(3) Demodulation along readout dimensions
(4) Decimation along readout dimension
(5) Phase correction with reference signal
Dimensions: [averages, coils, phase encoding, readout]
Reference signal is stored in the last entry of the coil dimension.
Parameters
----------
parameter
Acquisition parameter
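Examples
--------
A numpy/scipy sketch of steps (3) and (4), demodulation followed by FIR decimation, using made-up values
for the sampling rate, demodulation frequency and decimation factor (not values taken from this class):
>>> import numpy as np
>>> from scipy import signal
>>> f_spcm, f_larmor, decimation = 20e6, 2.048e6, 200
>>> data = np.ones((1, 1, 4000), dtype=complex)  # [coils, phase encoding, readout]
>>> data = data * np.exp(2j * np.pi * np.arange(data.shape[-1]) * f_larmor / f_spcm)
>>> data = signal.decimate(data, q=decimation, ftype="fir")
>>> data.shape
(1, 1, 20)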
"""
readout_sizes = [data.shape[-1] for data in self.rx_card.rx_data]
grouped_gates: dict[int, list] = {
readout_sizes[k]: [] for k in sorted(np.unique(readout_sizes, return_index=True)[1])
}
for data in self.rx_card.rx_data:
grouped_gates[data.shape[-1]].append(data)
gate_lengths = [np.stack(group, axis=1) for group in grouped_gates.values()]
raw_size = len(self._raw)
# Define channel dependent scaling
scaling = np.expand_dims(self.rx_card.rx_scaling[:self.rx_card.num_channels.value], axis=(-1, -2))
for k, data in enumerate(gate_lengths):
# Extract digital reference signal from channel 0
_ref = (data[0, ...].astype(np.uint16) >> 15).astype(float)[None, ...]
# Remove digital signal from channel 0
data[0, ...] = data[0, ...] << 1
data = data.astype(np.int16) * scaling
# Stack signal and reference in coil dimension
data = np.concatenate((data, _ref), axis=0)
# Append unprocessed data without post processing (last coil dimension entry contains reference)
if raw_size > 0:
self._unproc[k] = np.concatenate((self._unproc[k], data[None, ...]), axis=0)
else:
self._unproc.append(data[None, ...])
print("Demodulation at freq.:", parameter.larmor_frequency)
# Demodulation and decimation
data = data * np.exp(2j * np.pi * np.arange(data.shape[-1]) * parameter.larmor_frequency / self.f_spcm)
# Always decimate the reference signal with moving average filter
ref_dec = ddc.filter_moving_average(data[-1, ...], decimation=parameter.decimation, overlap=8)[None, ...]
# Extract the demodulated signal data
data = data[:-1, ...]
# Switch case for DDC function
match parameter.ddc_method:
case DDCMethod.CIC:
data = ddc.filter_cic_fir_comp(data, decimation=parameter.decimation, number_of_stages=5)
case DDCMethod.AVG:
data = ddc.filter_moving_average(data, decimation=parameter.decimation, overlap=8)
case _:
# Default case is FIR decimation
data = signal.decimate(data, q=parameter.decimation, ftype="fir")
# Apply phase correction with the decimated reference signal
data = data * np.exp(-1j * np.angle(ref_dec))
# Append to global raw data list
if raw_size > 0:
self._raw[k] = np.concatenate((self._raw[k], data[None, ...]), axis=0)
else:
self._raw.append(data[None, ...])