import os
import numpy as np
import torch
import pydicom
import matplotlib.pyplot as plt
from tcia_utils import nbia

from monai.data import CacheDataset
from monai.bundle import ConfigParser, download
from monai.transforms import LoadImage, LoadImaged, Orientation, Orientationd, EnsureChannelFirst, EnsureChannelFirstd, Compose
#from rt_utils import RTStructBuilder
import json

# Config
# Root directory for downloaded DICOM data and the MONAI bundle. It can be
# overridden with the CT_SLICER_DATA_DIR environment variable; by default it is
# the original hard-coded Windows path.
DATA_DEST_DIR = os.environ.get('CT_SLICER_DATA_DIR', 'C:\\Users\\U1\\Desktop\\ct_slicer\\data\\')
# Name of the DICOM series folder under DATA_DEST_DIR (presumably the series'
# SeriesInstanceUID as created by the TCIA download - confirm).
CT_FILE_ID = '1.3.6.1.4.1.14519.5.2.1.3320.3273.193828570195012288011029757668'
DOWNLOAD_DATA = False  # Whether to download the CT/MRI series from TCIA on this run.

SLICE_FILE_NAME = '1-394.dcm'  # Shown in the figure window title.
GENERATE_NUMPY_CT_SAVE = False  # Whether to load the series and cache a slice + metadata as .npy files.

GENERATE_MASK_IMAGE = False  # Whether to run the segmentation model and cache segmentation.npy. Takes very long on CPU.

# Divisor applied to every voxel spacing component before they are multiplied
# into a voxel volume. DICOM spacing is in mm, so 10 converts to cm and the
# resulting organ volumes are in cm^3.
SPACING_COEFF = 10

model_name = "wholeBody_ct_segmentation"  # MONAI model-zoo bundle name of the model being used.
NEEDS_DOWNLOAD_MODEL = False  # Whether to download the model bundle on this run.

slice_idx = 250  # 0-based index (along the second spatial axis) of the slice to display.

model_path = os.path.join(DATA_DEST_DIR, model_name, 'models', 'model_lowres.pt')
# In inference.json, set output_dir to "$'C:/Users/U1/Desktop/ct_slicer/data/mask'".
config_path = os.path.join(DATA_DEST_DIR, model_name, 'configs', 'inference.json')
# Without a powerful machine, set the displayable_configs "highres" key to false
# in inference.json, and set the device to "$'cpu'".

# CODE
# Folder holding the DICOM series that is loaded and segmented below.
CT_folder = os.path.join(DATA_DEST_DIR, CT_FILE_ID)

if DOWNLOAD_DATA:
    # Fetch the series in the shared TCIA/NBIA cart into DATA_DEST_DIR
    # (presumably into a sub-folder named after the series UID, i.e. CT_folder - confirm).
    cart_name = "nbia-56561691129779503"
    cart_data = nbia.getSharedCart(cart_name)
    df = nbia.downloadSeries(cart_data, format="df", path=DATA_DEST_DIR)


if GENERATE_NUMPY_CT_SAVE:
    # Load the whole DICOM series in CT_folder as one volume. CT.meta holds the
    # reader metadata (e.g. affine / original_affine, spacing).
    image_loader = LoadImage(image_only=True)
    CT = image_loader(CT_folder)
    # The loaded volume has no channel axis; add one in front.
    channel_transform = EnsureChannelFirst(channel_dim="no_channel")
    CT = channel_transform(CT)
    # Reorient to LPS (Left, Posterior, Superior) axis codes so the slice is not upside down.
    orientation_transform = Orientation(axcodes="LPS")
    CT = orientation_transform(CT)

    # Index 0 selects the single channel. Uses slice_idx (not a hard-coded
    # index) so the CT panel shows the same slice as the segmentation panel.
    CT_coronal_slice = CT[0, :, slice_idx].cpu().numpy()

    # Cache results in the working directory; the display code reloads them.
    np.save("CT_meta.npy", CT.meta, allow_pickle=True)
    np.save("CT_raw.npy", CT_coronal_slice)

# Segmentation model from: https://project-monai.github.io/model-zoo.html#/model/wholeBody_ct_segmentation

if NEEDS_DOWNLOAD_MODEL:
    # Download the MONAI bundle into DATA_DEST_DIR; model_path and config_path
    # expect it under DATA_DEST_DIR/<model_name>.
    download(name=model_name, bundle_dir=DATA_DEST_DIR)

# Segmentation volume; filled in below or reloaded from segmentation.npy later.
segmentation = None

if GENERATE_MASK_IMAGE:
    # Build the transforms, network and inferer described by the bundle's inference config.
    config = ConfigParser()
    config.read_config(config_path)

    preprocessing = config.get_parsed_content("preprocessing")
    model = config.get_parsed_content("network")

    torch.set_num_threads(8)
    torch.backends.mkldnn.enabled = True

    # NOTE(review): torch.load unpickles the checkpoint file; only load bundles
    # from a trusted source.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device("cpu"))
    )
    model.eval()

    # Inferer takes the data and the model and returns the model's output
    # (original note: it slices the volume into 96x96x96 chunks).
    inferer = config.get_parsed_content("inferer")
    postprocessing = config.get_parsed_content("postprocessing")

    # Preprocess the single series; cache_rate=1.0 keeps the item in memory.
    dataset = CacheDataset(
        data=[{'image': CT_folder}],
        transform=preprocessing,
        cache_rate=1.0,
        num_workers=0,
    )
    data = dataset[0]

    with torch.no_grad():  # Inference only, no back propagation.
        # unsqueeze(0) adds the batch dimension the inferer expects.
        data['pred'] = inferer(data['image'].unsqueeze(0), network=model)

    # Drop the leading dimension of both entries (the batch dim added above for
    # 'pred') before postprocessing.
    data['pred'] = data['pred'][0]
    data['image'] = data['image'][0]

    data = postprocessing(data)

    # Channel 0 of the prediction, flipped along axis 2.
    # NOTE(review): the flip presumably aligns the mask with the CT display orientation - confirm.
    segmentation = torch.flip(data['pred'][0], dims=[2])
    segmentation = segmentation.cpu().numpy()
    np.save("segmentation.npy", segmentation)

# Reload the arrays cached by the GENERATE_* steps (on this or a previous run).
try:
    segmentation = np.load("segmentation.npy")
    CT_coronal_slice = np.load("CT_raw.npy")
    # allow_pickle is required because the metadata dict was pickled by np.save;
    # only load a CT_meta.npy that this script produced.
    CT_meta = np.load("CT_meta.npy", allow_pickle=True).item()
except (OSError, ValueError, EOFError) as err:
    print("Failed to display! Some data was not present in a numpy format. Set their flags to true in the configs and re-run the program.")
    # Nothing below can work without these arrays: stop here instead of failing
    # later with an unrelated AttributeError/TypeError.
    raise SystemExit(f"Could not load cached arrays: {err}") from err

# Transpose both slices so they are drawn in the same layout.
image2 = CT_coronal_slice.T
segmentation_coronal_slice = segmentation[:, slice_idx]
image = segmentation_coronal_slice.T

fig, axes = plt.subplots(
    1,
    2,
    figsize=(9, 8),
    num=f"Sliced image – file: {SLICE_FILE_NAME}, index: {slice_idx}",
)

# One entry per panel: (data, colormap, title, colorbar label).
# Left panel is the CT slice, right panel is the segmentation labels.
panel_specs = (
    (image2, "Greys_r", "CT/MRI slice", "Hounsfield Units"),
    (image, "nipy_spectral", "Segmentation", "Label"),
)
for ax, (panel_data, cmap_name, panel_title, cbar_label) in zip(axes, panel_specs):
    mesh = ax.pcolormesh(panel_data, cmap=cmap_name)
    ax.set_title(panel_title)
    ax.axis("off")
    cbar = fig.colorbar(mesh, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(cbar_label)

plt.tight_layout()
plt.show()

# Example: compute the bladder volume (segmentation label 104).

def calculate_organ_volume(organ_id, seg=None, meta=None, spacing_coeff=None):
    """Print and return the volume of one organ label in the segmentation.

    The volume is (number of voxels equal to ``organ_id``) times the volume of
    one voxel. The voxel volume is the product of the spacing components, each
    divided by ``spacing_coeff`` (10 turns mm spacing into cm, giving cm^3).

    Args:
        organ_id: Label value to count in the segmentation.
        seg: Segmentation array; defaults to the module-level ``segmentation``.
        meta: Metadata mapping with a ``'spacing'`` entry (list, tuple or
            array); defaults to the module-level ``CT_meta``.
        spacing_coeff: Per-axis spacing divisor; defaults to ``SPACING_COEFF``.

    Returns:
        The organ volume as a float (0.0 when the label is absent).
    """
    if seg is None:
        seg = segmentation
    if meta is None:
        meta = CT_meta
    if spacing_coeff is None:
        spacing_coeff = SPACING_COEFF

    voxel_count = int(np.count_nonzero(np.asarray(seg) == organ_id))
    # np.asarray lets 'spacing' be a list/tuple as well as an ndarray.
    spacing = np.asarray(meta['spacing'], dtype=float)
    voxel_volume = float(np.prod(spacing / spacing_coeff))
    organ_volume = voxel_count * voxel_volume
    print(f'Organ volume with id: {organ_id}: {organ_volume:.3f}cm^3')
    return organ_volume
    
calculate_organ_volume(organ_id=104)  # Bladder label in the segmentation.