"""Acquisition Control Class."""
import copy
import logging
import logging.config
import os
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
import numpy as np
from console.interfaces.acquisition_data import AcquisitionData
from console.interfaces.acquisition_parameter import AcquisitionParameter
from console.interfaces.dimensions import Dimensions
from console.interfaces.unrolled_sequence import UnrolledSequence
from console.pulseq_interpreter.sequence_provider import Sequence, SequenceProvider
from console.spcm_control.rx_device import RxCard
from console.spcm_control.tx_device import TxCard
from console.utilities.load_config import get_instances
# Logging levels accepted for the console and log-file handlers, from DEBUG up to CRITICAL.
LOG_LEVELS = [getattr(logging, level_name) for level_name in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")]
class AcquisitionControl:
    """Acquisition control class.

    The main functionality of the acquisition control is to orchestrate transmit and receive cards using
    ``TxCard`` and ``RxCard`` instances.
    """

    def __init__(
        self,
        configuration_file: str,
        nexus_data_dir: str | None = os.path.join(Path.home(), "nexus-console"),
        file_log_level: int = logging.INFO,
        console_log_level: int = logging.INFO,
    ):
        """Construct acquisition control class.

        Create instances of sequence provider, tx and rx card.
        Setup the measurement cards and get parameters required for a measurement.

        Parameters
        ----------
        configuration_file
            Path to configuration yaml file which is used to create measurement card and sequence
            provider instances.
        nexus_data_dir
            Nexus console default directory to store logs, states and acquisition data.
            Defaults to ``<home>/nexus-console``; passing ``None`` selects the same default.
        file_log_level
            Set the logging level for log file. Logfile is written to the session folder.
        console_log_level
            Set the logging level for the terminal/console output.

        Raises
        ------
        ValueError
            If ``file_log_level`` or ``console_log_level`` is not one of ``LOG_LEVELS``.
        """
        if nexus_data_dir is None:
            nexus_data_dir = os.path.join(Path.home(), "nexus-console")

        # Create session path (contains all acquisitions of one day)
        session_folder_name = datetime.now().strftime("%Y-%m-%d") + "-session/"
        self.session_path = os.path.join(nexus_data_dir, session_folder_name)
        os.makedirs(self.session_path, exist_ok=True)

        self._setup_logging(console_level=console_log_level, file_level=file_log_level)
        self.log = logging.getLogger("AcqCtrl")
        # dictConfig(disable_existing_loggers=True) in _setup_logging disables every logger that
        # already exists; from the second instantiation on that includes "AcqCtrl" itself.
        self.log.disabled = False
        self.log.info("--- Acquisition control started\n")

        # Get instances from configuration file
        ctx = get_instances(configuration_file)
        self.seq_provider: SequenceProvider = ctx[0]
        self.tx_card: TxCard = ctx[1]
        self.rx_card: RxCard = ctx[2]
        self.seq_provider.output_limits = self.tx_card.max_amplitude

        # Setup the cards
        self.is_setup: bool = False
        try:
            if self.tx_card.connect() and self.rx_card.connect():
                self.log.info("Setup of measurement cards successful.")
                self.is_setup = True
        except Exception:
            self.log.exception("Error during card connection.")
            if self.tx_card:
                self.tx_card.disconnect()
            if self.rx_card:
                self.rx_card.disconnect()

        # Get the rx sampling rate for DDC
        self.f_spcm = self.rx_card.sample_rate * 1e6
        # Set sequence provider max. amplitude per channel according to values from tx_card
        self.seq_provider.max_amp_per_channel = self.tx_card.max_amplitude
        self.sequence: UnrolledSequence | None = None

        # Results of the latest run(); initialised here so post_processing() never hits a missing attribute
        self.store_unprocessed: bool = False
        self.receive_data: list = []

        # Attributes for data and dwell time of downsampled signal
        self._raw: list[np.ndarray] = []
        self._unproc: list[np.ndarray] = []

    def __del__(self):
        """Class destructor disconnecting measurement cards.

        Attributes are looked up defensively because ``__del__`` also runs on instances whose
        ``__init__`` raised before the cards or the logger were assigned.
        """
        tx_card = getattr(self, "tx_card", None)
        rx_card = getattr(self, "rx_card", None)
        if tx_card:
            tx_card.disconnect()
        if rx_card:
            rx_card.disconnect()
        log = getattr(self, "log", None)
        if log is not None:
            log.info("Measurement cards disconnected")
            log.info("Acquisition control terminated\n---------------------------------------------------\n")

    def _setup_logging(self, console_level: int, file_level: int) -> None:
        """Configure root logging with a log file in the session folder and a console handler.

        Any handlers already attached to the root logger (e.g. from a previous instance) are
        replaced, so repeated instantiation does not duplicate console output.

        Parameters
        ----------
        console_level
            Minimum level written to the terminal (stderr).
        file_level
            Minimum level written to ``console.log`` in the session folder.

        Raises
        ------
        ValueError
            If either level is not one of ``LOG_LEVELS``.
        """
        # Check if log levels are valid
        if console_level not in LOG_LEVELS:
            raise ValueError("Invalid console log level")
        if file_level not in LOG_LEVELS:
            raise ValueError("Invalid file log level")

        # Disable existing loggers
        logging.config.dictConfig({"version": 1, "disable_existing_loggers": True})  # type: ignore[attr-defined]

        # Log file in the session folder (session_path ends with a separator, join keeps the same path)
        file_handler = logging.FileHandler(os.path.join(self.session_path, "console.log"), mode="a")
        file_handler.setLevel(file_level)

        # Handler which writes messages of console_level or higher to sys.stderr
        log_console = logging.StreamHandler()
        log_console.setLevel(console_level)
        log_console.setFormatter(logging.Formatter("%(name)-7s: %(levelname)-8s >> %(message)s"))

        # The root level must admit the more verbose of both handlers; each handler filters itself.
        # force=True replaces handlers from earlier calls instead of making basicConfig a no-op.
        logging.basicConfig(
            level=min(file_level, console_level),
            format="%(asctime)s %(name)-7s: %(levelname)-8s >> %(message)s",
            datefmt="%d-%m-%Y, %H:%M",
            handlers=[file_handler, log_console],
            force=True,
        )

    def set_sequence(self, sequence: str | Sequence, parameter: AcquisitionParameter) -> None:
        """Set sequence and acquisition parameter.

        Parameters
        ----------
        sequence
            Path to a pulseq sequence file (``.seq``) or a pypulseq ``Sequence`` instance.
        parameter
            Set of acquisition parameters which are required for the acquisition.

        Raises
        ------
        AttributeError
            Invalid sequence provided (neither a path string nor a ``Sequence``).
        FileNotFoundError
            Invalid file ending of sequence file.
        """
        try:
            # Check sequence
            if isinstance(sequence, Sequence):
                self.seq_provider.from_pypulseq(sequence)
            elif isinstance(sequence, str):
                if not sequence.endswith(".seq"):
                    raise FileNotFoundError("Invalid sequence file.")
                self.seq_provider.read(sequence)
            else:
                # Without this, a previously loaded sequence would silently be unrolled below
                raise AttributeError("Invalid sequence provided.")
        except (FileNotFoundError, AttributeError) as err:
            self.log.exception(err)
            raise

        # Reset unrolled sequence
        self.sequence = None
        self.log.info(
            "Unrolling sequence: %s",
            self.seq_provider.definitions["Name"].replace(" ", "_"),
        )
        # Update sequence parameter hash and calculate sequence
        self.sequence = self.seq_provider.unroll_sequence(parameter=parameter)
        self.log.info("Sequence duration: %s s", self.sequence.duration)

    def run(self, store_unprocessed: bool = False) -> AcquisitionData:
        """Run an acquisition job.

        Parameters
        ----------
        store_unprocessed
            Flag for whether to keep the raw, undecimated data after decimation.

        Returns
        -------
        AcquisitionData
            Receive data of all averages together with the sequence, session path,
            card/provider metadata and the acquisition parameters.

        Raises
        ------
        RuntimeError
            The measurement cards are not setup properly, or no ADC events were received.
        ValueError
            No sequence is set, or averages are missing from the received data.
        """
        try:
            # Check setup
            if not self.is_setup:
                raise RuntimeError("Measurement cards are not setup.")
            if self.sequence is None:
                raise ValueError("No sequence set, call set_sequence() to set a sequence and acquisition parameter.")
        except (RuntimeError, ValueError) as err:
            self.log.exception(err)
            raise

        sequence = self.sequence
        parameter = sequence.parameter
        # Define timeout for acquisition process: 5 sec + sequence duration
        timeout = 5 + sequence.duration
        self.store_unprocessed = store_unprocessed
        # Create a list to store rx_data for all averages
        self.receive_data = []
        self.num_adc_events = len(sequence.rx_data)

        # Set gradient offset values
        self.tx_card.set_gradient_offsets(parameter.gradient_offset, self.seq_provider.high_impedance[1:])
        try:
            for k in range(parameter.num_averages):
                self._acquire_average(sequence, k, timeout)
                if parameter.averaging_delay > 0:
                    time.sleep(parameter.averaging_delay)
        finally:
            # Reset gradient offset values, also when an average raised
            self.tx_card.set_gradient_offsets(Dimensions(x=0, y=0, z=0), self.seq_provider.high_impedance[1:])

        if not self.receive_data:
            raise RuntimeError("No ADC events present")
        self.log.debug("Total number of ADC events: %s", len(self.receive_data))
        # Process all the data at the end of the acquisition
        self.post_processing(parameter)

        try:
            # Plain ints (not numpy scalars) keep the error message readable
            received_averages = {data.average_index for data in self.receive_data}
            if len(received_averages) != parameter.num_averages:
                missing_averages = [
                    avg + 1 for avg in range(parameter.num_averages) if avg not in received_averages
                ]
                raise ValueError(f"Missing averages: {missing_averages} out of {parameter.num_averages}")
        except ValueError as err:
            self.log.exception(err)
            raise

        return AcquisitionData(
            receive_data=self.receive_data,
            sequence=self.seq_provider.to_pypulseq(),
            session_path=self.session_path,
            meta={
                self.tx_card.__name__: self.tx_card.dict(),
                self.rx_card.__name__: self.rx_card.dict(),
                self.seq_provider.__name__: self.seq_provider.dict(),
            },
            acquisition_parameters=parameter,
        )

    def _acquire_average(self, sequence: UnrolledSequence, index: int, timeout: float) -> None:
        """Play ``sequence`` once and append its labelled receive data to ``self.receive_data``.

        Parameters
        ----------
        sequence
            Unrolled sequence to play.
        index
            Zero-based average index stored on every received data object.
        timeout
            Seconds to wait for all ADC gates before giving up with a warning.
        """
        # Create a copy of rx_data to store the current acquisition in and label scan number.
        self.rx_card.rx_data = copy.deepcopy(sequence.rx_data)
        self.log.info("Acquisition %s/%s", index + 1, sequence.parameter.num_averages)

        # Start measurement card operations
        self.rx_card.start_operation()
        try:
            while not self.rx_card.is_receiving.is_set():
                time.sleep(0.01)
            self.tx_card.start_operation(sequence)

            # Poll every 10 ms until all ADC gates arrived (at least one) or the timeout elapsed
            time_start = time.time()
            while (num_gates := self.rx_card.total_gates) < sequence.adc_count or num_gates == 0:
                time.sleep(0.01)
                if (time.time() - time_start) > timeout:
                    # Could not receive all the data before timeout
                    self.log.warning(
                        "Acquisition Timeout: Only received %s/%s adc events",
                        num_gates, sequence.adc_count
                    )
                    break

            # Append the receive data with current scan data
            scan_data: list = self.rx_card.rx_data.copy()
            self.rx_card.rx_data = None
            for data in scan_data:
                data.average_index = index
            self.receive_data.extend(scan_data)
        finally:
            # Stop both cards even if the acquisition failed part-way
            self.tx_card.stop_operation()
            self.rx_card.stop_operation()

    def post_processing(self, parameter: AcquisitionParameter) -> None:
        """Process acquired NMR data.

        Post processing contains the following steps (per readout sample size):
        (1) Scaling of receive data
        (2) Demodulation along readout dimensions
        (3) Decimation along readout dimension

        Parameters
        ----------
        parameter
            Acquisition parameter; its ``larmor_frequency`` is assigned to every receive data object.

        Raises
        ------
        Exception
            Any exception raised by a receive data object's ``process_data`` is propagated.
        """
        # Set the larmor frequency for all data to the defined larmor_frequency
        for rx_data in self.receive_data:
            rx_data.larmor_frequency = parameter.larmor_frequency

        store_unprocessed = self.store_unprocessed
        # Process the data in parallel. The map iterator is consumed because Executor.map only
        # re-raises worker exceptions when results are retrieved; otherwise errors vanish silently.
        with ThreadPoolExecutor() as executor:
            list(
                executor.map(
                    lambda rx_obj: rx_obj.process_data(store_unprocessed=store_unprocessed),
                    self.receive_data,
                )
            )