# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT

from abc import ABC, abstractmethod
import atexit
import os
from pathlib import Path
from typing import List
import natsort
import numpy as np
from tqdm import tqdm

import multiprocess as mp
from cli.serdes import DataViewsDeserializer, DataViewsSerializer
from cli.writers.views import DataMerger
from mpp.core.views import ViewType
from mpp.parsers.data_parser import Partition


class _DataProcessor(ABC):
    """Abstract base for the strategies that push event partitions through the view pipeline.

    An instance keeps the three collaborators it was constructed with and a
    ``detail_views`` attribute that starts out as ``None``.
    """

    def __init__(self, data_accumulator, view_generator, view_writer):
        """Remember the collaborators; no detail views have been produced yet.

        :param data_accumulator: object that accumulates results across partitions.
        :param view_generator: object that derives views from a partition's event data.
        :param view_writer: object that receives the generated views for output.
        """
        self.data_accumulator, self.view_generator, self.view_writer = (
            data_accumulator, view_generator, view_writer)
        self.detail_views = None

    @abstractmethod
    def process(self, event_reader, partitions: List[Partition], parallel_cores: int = 1, no_detail_views: bool = False):
        """Process every partition; each concrete subclass supplies the strategy.

        :param event_reader: source of event data for a partition.
        :param partitions: the partitions to process.
        :param parallel_cores: requested degree of parallelism.
        :param no_detail_views: flag asking that detail views be left out.
        """


class _SerialDataProcessor(_DataProcessor):
    """Processes the partitions one after another in the calling process."""

    def process(self, event_reader, partitions: List[Partition], parallel_cores: int = 1, no_detail_views: bool = False):
        """Read and process each partition in order, showing a progress bar.

        :param event_reader: callable invoked as ``event_reader(partition=..., chunk_size=0)``.
        :param partitions: the partitions to process, in order.
        :param parallel_cores: not used by this implementation.
        :param no_detail_views: not used by this implementation.
        """
        # TODO: refactor main with class that has Parallel and Serial Processing classes (and a factory)
        progress = tqdm(partitions)
        for position, current in enumerate(progress, start=1):
            progress.set_description(f'Processing partition {position} out of {len(partitions)}')
            frame, first, last = self.__get_dataframe_from_partition(event_reader, current)
            self.data_accumulator = self.process_dataframe(frame, first, last)

    def process_dataframe(self, event_df, first_sample, last_sample):
        """Fold one partition's event data into the accumulator and return the accumulator.

        Aggregates are always updated. Statistics are updated from the detail views when
        any were generated (inside ``handle_detail_views``), otherwise from ``event_df``.
        """
        aggregates = self.view_generator.compute_aggregates(event_df)
        self.data_accumulator.update_aggregates(aggregates)
        self.handle_detail_views(event_df, first_sample, last_sample)
        if self.detail_views:
            # Statistics were already updated from the detail views.
            return self.data_accumulator
        self.data_accumulator.update_statistics(df=event_df)
        return self.data_accumulator

    def handle_detail_views(self, event_df, first_sample, last_sample):
        """Generate the partition's detail views, then write them out and update statistics.

        The generated views are stored on ``self.detail_views``; nothing is written and no
        statistics are updated when the generator produced nothing truthy.
        """
        views = self.view_generator.generate_detail_views(event_df)
        self.detail_views = views
        if not views:
            return
        self.view_writer.write(list(views.values()), first_sample, last_sample)
        self.data_accumulator.update_statistics(views)

    @staticmethod
    def __get_dataframe_from_partition(event_reader, partition: 'Partition'):
        """Return ``(event data, first sample, last sample)`` for ``partition``.

        The event data is the first item yielded by the reader for this partition.
        """
        frames = event_reader(partition=partition, chunk_size=0)
        return next(frames), partition.first_sample, partition.last_sample


class _ParallelDataProcessor(_DataProcessor):
    """Processes partitions concurrently in a pool of worker processes.

    Each worker hands the views it computes for its partition to ``self.serializer``. Once
    the pool has finished, the parent process reads back the view files found in the
    directory yielded by ``self.serializer.parent_dir`` and folds them into
    ``self.data_accumulator``.
    """

    # Largest worker pool this processor will create.
    MAX_NUMBER_OF_CORES = 60

    def __init__(self, data_accumulator, view_generator, view_writer):
        """Store the collaborators and create the serializer used by the workers."""
        super().__init__(data_accumulator, view_generator, view_writer)
        # Assigned by process(); workers call it to read their partition.
        self.event_reader = None
        self.serializer = DataViewsSerializer()

    def process(self, event_reader, partitions: List[Partition], parallel_cores: int = 1, no_detail_views: bool = False):
        """Process all partitions in a worker pool, then merge the per-partition results.

        The pool size is the smallest of the available core count, the request, the number
        of partitions and ``MAX_NUMBER_OF_CORES``, and is never below one. With an empty
        partition list the pool simply has no work, which matches the serial processor's
        no-op on empty input.

        :param event_reader: callable invoked as ``event_reader(partition=..., chunk_size=0)``;
            ``next()`` is taken on its result to obtain the partition's event data.
        :param partitions: partitions to process; each one is a separate pool task.
        :param parallel_cores: requested number of worker processes. A falsy value (``None``
            or ``0``) requests one worker per available core.
        :param no_detail_views: when true, detail views are still read back to update the
            statistics but are not passed on to the detail-view merger.
        """
        self.event_reader = event_reader
        limit = self.MAX_NUMBER_OF_CORES
        with self.serializer.parent_dir as tmp_dir:
            available_cores = mp.cpu_count()
            partition_count = len(partitions)
            user_specified = bool(parallel_cores)
            if user_specified and parallel_cores > limit:
                print(f'Warning: parallel processing on greater than {limit} cores is not currently '
                      f'supported')
            requested_cores = parallel_cores if user_specified else available_cores
            number_of_cores = self._resolve_worker_count(requested_cores, available_cores, partition_count, limit)

            # Tell the user when the partition count, not the request, limited the pool size.
            limited_by_partitions = number_of_cores == partition_count and partition_count != requested_cores
            partition_str = f' due to {partition_count} partitions in the input file' if limited_by_partitions else ''
            parallel_option_str = ' (use -p to specify number of processes)' if not user_specified else ''
            print(f'Processing in parallel with {number_of_cores} out of {available_cores} processes' + partition_str
                  + parallel_option_str + '...', end='', flush=True)

            # Also run the serializer's cleanup at interpreter exit.
            atexit.register(self.serializer.cleanup)
            with mp.Pool(processes=number_of_cores) as pool:
                pool.map(self.process_partition, partitions)
                pool.close()
                pool.join()

            # Read the serialized views back and merge them in the parent process.
            deserializer = DataViewsDeserializer()
            self.data_accumulator = self._handle_temp_detail_view_files(no_detail_views, deserializer, tmp_dir)
            self.data_accumulator = self._handle_temp_summary_view_files(deserializer, tmp_dir)

    @staticmethod
    def _resolve_worker_count(requested, available, partition_count, limit):
        """Return the pool size: the smallest of the four inputs, clamped to at least one.

        The clamp matters for an empty partition list or a non-positive request: either
        would otherwise give a size below one, which the pool constructor rejects.
        """
        return max(1, min(available, requested, partition_count, limit))

    @staticmethod
    def _list_temp_view_files(tmp_dir, view_type):
        """Return the entries of ``tmp_dir`` whose name contains ``__<view_type.name>__``, natsorted."""
        marker = '__' + view_type.name + '__'
        return natsort.natsorted([name for name in os.listdir(tmp_dir) if marker in name])

    def process_partition(self, partition: 'Partition'):
        """Pool task: compute one partition's views and pass them to the serializer.

        Prints a ``.`` as a progress tick once the partition's event data has been read.
        """
        event_reader = self.event_reader(partition=partition, chunk_size=0)
        event_df = next(event_reader)
        print('.', end='', flush=True)
        # Aggregates are always serialized; detail views only when any were generated.
        summary_computations = self.view_generator.compute_aggregates(event_df)
        self.serializer.write_views(summary_computations, partition=partition)
        detail_views = self.view_generator.generate_detail_views(event_df)
        if detail_views:
            self.serializer.write_views(list(detail_views.values()), partition=partition)

    def _handle_temp_detail_view_files(self, no_detail_views, deserializer, tmp_dir):
        """Read back every serialized detail-view file and return the updated accumulator.

        Each file's views update the accumulator's statistics. They are also passed to the
        detail-view merger unless ``no_detail_views`` is true.
        """
        detail_view_files = self._list_temp_view_files(tmp_dir, ViewType.DETAILS)
        data_merger = DataMerger(self.view_writer)
        include_details = not no_detail_views
        # End the line of progress dots printed by the workers.
        print()
        pbar = tqdm(detail_view_files)
        for idx, filename in enumerate(pbar):
            pbar.set_description(f'Writing partition {idx + 1} out of {len(detail_view_files)} to CSV...')
            detail_views, partition = deserializer.read_views(Path(tmp_dir) / filename)
            if include_details:
                data_merger.write_to_detail_views(detail_views, partition)
            self.data_accumulator.update_statistics(detail_views)
        return self.data_accumulator

    def _handle_temp_summary_view_files(self, deserializer, tmp_dir):
        """Read back every serialized summary-view file and return the updated accumulator."""
        for filename in self._list_temp_view_files(tmp_dir, ViewType.SUMMARY):
            # The partition stored alongside the views is not needed for aggregation.
            summary_views, _ = deserializer.read_views(Path(tmp_dir) / filename)
            self.data_accumulator.update_aggregates(list(summary_views.values()))
        return self.data_accumulator


class DataProcessorFactory:
    """Creates the data processor that matches the requested execution mode."""

    # Processor class registered for each mode.
    data_processors = {
        'parallel': _ParallelDataProcessor,
        'serial': _SerialDataProcessor
    }

    def create(self, is_parallel, data_accumulator, view_generator, view_writer):
        """Instantiate the parallel processor when ``is_parallel`` is truthy, else the serial one.

        The three collaborators are passed to the chosen class's constructor unchanged.
        """
        mode = 'parallel' if is_parallel else 'serial'
        processor_cls = self.data_processors[mode]
        return processor_cls(data_accumulator, view_generator, view_writer)
