# Source code for Hive_ML_split_by_class

#!/usr/bin/env python

import os
import pandas as pd
import ast
import shutil
from argparse import ArgumentParser, RawTextHelpFormatter
from pathlib import Path
from textwrap import dedent
import importlib
import json
import Hive_ML.configs

# Help texts passed to the ArgumentParser built in get_arg_parser().
DESC = dedent("""
    Script to split the dataset by the given class label.
    """)
# The example must include every required option (--root-dir and --config-file),
# otherwise copying it verbatim makes argparse exit with an error.
EPILOG = dedent("""
    Example call:
    ::
        {filename} --root-dir /path/to/MAMA-MIA --config-file /path/to/config.json
    """.format(filename=Path(__file__).name))


def get_arg_parser():
    """Create the command-line parser for the split-by-class script.

    Both options are mandatory string arguments.

    Returns
    -------
    ArgumentParser
        Parser accepting ``--root-dir`` and ``--config-file``, using the
        module-level ``DESC``/``EPILOG`` texts and a raw-text help formatter.
    """
    parser = ArgumentParser(
        description=DESC,
        epilog=EPILOG,
        formatter_class=RawTextHelpFormatter,
    )
    # (flag, help text) pairs, registered in this order.
    required_options = (
        (
            "--root-dir",
            "Root directory containing MAMA-MIA dataset (should contain 'images', 'segmentations', and clinical_and_imaging_info.xlsx)",
        ),
        (
            "--config-file",
            "Config file containing the configuration for the dataset.",
        ),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    return parser
def _load_config(config_file):
    """Load the JSON dataset configuration.

    ``config_file`` is first opened as a filesystem path; if no such file
    exists, it is looked up as a resource bundled in ``Hive_ML.configs``.

    Parameters
    ----------
    config_file : str
        Path to a JSON file, or the name of a JSON file shipped in ``Hive_ML.configs``.

    Returns
    -------
    dict
        The parsed JSON configuration.

    Raises
    ------
    FileNotFoundError
        If the file exists neither on disk nor in ``Hive_ML.configs``.
    """
    # ``import importlib`` alone does not guarantee that the ``resources``
    # submodule is loaded, so import it explicitly before using it.
    import importlib.resources

    try:
        with open(config_file) as json_file:
            return json.load(json_file)
    except FileNotFoundError:
        pass  # Not a path on disk: fall back to the packaged configs below.
    with importlib.resources.path(Hive_ML.configs, config_file) as json_path:
        with open(json_path) as json_file:
            return json.load(json_file)


def _build_data_dict(data_dir, segmentation_dir, clinical_data):
    """Collect, per subject, image paths, timepoints, mask path and class label.

    Parameters
    ----------
    data_dir : str
        Directory containing one sub-directory per subject with ``.nii.gz`` images.
    segmentation_dir : str
        Directory holding one ``<subject>.nii.gz`` segmentation per subject.
    clinical_data : pandas.DataFrame
        Table with at least the ``patient_id``, ``acquisition_times`` and ``pcr`` columns.

    Returns
    -------
    dict
        Maps each subject ID to a dict with keys ``images`` (sorted paths),
        ``timepoints``, ``segmentation`` (path) and ``label`` (the ``pcr`` value).
        Subject folders with no row in ``clinical_data`` are skipped with a warning.
    """
    import warnings

    data_dict = {}
    for sub in os.listdir(data_dir):
        sub_dir = os.path.join(data_dir, sub)
        if not os.path.isdir(sub_dir):
            continue
        # Filter the clinical table once per subject instead of once per column.
        rows = clinical_data[clinical_data["patient_id"] == sub]
        if rows.empty:
            # Previously this crashed with an opaque IndexError on `.values[0]`.
            warnings.warn(f"Subject {sub} not found in clinical data; skipping.")
            continue

        images = sorted(os.path.join(sub_dir, img) for img in os.listdir(sub_dir) if img.endswith(".nii.gz"))
        acq_times = rows["acquisition_times"].values[0]
        # Text cells are parsed as a Python literal; other values are kept as-is.
        timepoints = ast.literal_eval(acq_times) if isinstance(acq_times, str) else acq_times

        data_dict[sub] = {
            "images": images,
            "timepoints": timepoints,
            "segmentation": os.path.join(segmentation_dir, f"{sub}.nii.gz"),
            "label": rows["pcr"].values[0],
        }
    return data_dict


def _copy_by_class(root_dir, data_dict, label_dict):
    """Copy each subject's mask and 4D image into the folder of its class.

    For every subject, files are written to ``<root_dir>/<label_dict[str(int(label))]>/<subject>/``:
    ``<subject>_mask.nii.gz`` (copied from the entry's segmentation path) and
    ``<subject>_image.nii.gz`` (copied from ``<root_dir>/4D_images/<subject>/<subject>.nii.gz``).

    Parameters
    ----------
    root_dir : str
        Dataset root; class folders are created directly under it.
    data_dict : dict
        Output of ``_build_data_dict``.
    label_dict : dict
        Maps the integer label as a string (e.g. ``"0"``) to a class folder name.

    Raises
    ------
    KeyError
        If a subject's integer label has no entry in ``label_dict``.
    """
    root = Path(root_dir)
    # Create every class folder up front, even if no subject ends up in it.
    for class_dir in label_dict.values():
        root.joinpath(class_dir).mkdir(parents=True, exist_ok=True)

    for sub, entry in data_dict.items():
        try:
            label_key = str(int(entry["label"]))
        except (TypeError, ValueError):
            # NaN (ValueError) or None (TypeError): the subject has no usable class label.
            continue
        output_dir = root.joinpath(label_dict[label_key], sub)
        output_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(Path(entry["segmentation"]), output_dir.joinpath(f"{sub}_mask.nii.gz"))
        shutil.copy(root.joinpath("4D_images", sub, f"{sub}.nii.gz"), output_dir.joinpath(f"{sub}_image.nii.gz"))


def main():
    """Split the MAMA-MIA dataset into one folder per class label.

    Reads ``clinical_and_imaging_info.xlsx`` from ``--root-dir`` and takes each
    subject's class from its ``pcr`` column. It then copies the subject's expert
    segmentation (``segmentations/expert``) and its image from ``4D_images`` into
    ``<root-dir>/<label_dict[label]>/<subject>/``. ``label_dict`` is read from
    the JSON given by ``--config-file``.
    """
    args = get_arg_parser().parse_args()
    root_dir = args.root_dir
    data_dir = os.path.join(root_dir, "images")
    segmentation_dir = os.path.join(root_dir, "segmentations", "expert")
    clinical_data = pd.read_excel(os.path.join(root_dir, "clinical_and_imaging_info.xlsx"))
    config_dict = _load_config(args.config_file)

    data_dict = _build_data_dict(data_dir, segmentation_dir, clinical_data)
    _copy_by_class(root_dir, data_dict, config_dict["label_dict"])
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()