import { TrainingReducer } from "modules/types/training";
import { PayloadAction } from "@reduxjs/toolkit";
import {
	MLFlowExperiment,
	MLFlowModelVersion,
	MLFlowRun,
} from "modules/types/training/ml-flow";

import { AirflowDagRun } from "modules/types/training/airflow";
import { getF1ScoreMetricName } from "common/services/ui.service";

/**
 * Upserts MLflow experiments into `experimentMap`, keyed by `experiment_id`.
 * Entries already present but absent from the payload are kept; a payload
 * experiment replaces any existing entry with the same id.
 *
 * @param state - current training slice state
 * @param action - carries the experiments to merge
 * @returns a new state object holding a new `experimentMap`
 */
export const setExperiments = (
	state: TrainingReducer,
	action: PayloadAction<{ experiments: MLFlowExperiment[] }>
): TrainingReducer => {
	const { experiments } = action.payload;

	// Fold the payload into a fresh copy of the existing map so the previous
	// map object is left untouched.
	const merged = experiments.reduce<{ [key: string]: MLFlowExperiment }>(
		(acc, experiment) => {
			acc[experiment.experiment_id] = experiment;
			return acc;
		},
		{ ...state.experimentMap }
	);

	return { ...state, experimentMap: merged };
};

/**
 * Upserts a batch of MLflow runs into `runMap`, keyed by `run_id`, and marks
 * the batch's highest-scoring run with `isBestRun: true`.
 *
 * Scoring uses the metric named by `getF1ScoreMetricName(isMultiLabel)`. The
 * batch counts as multi-label when any run has logged
 * `test_optimal_thresholds_empty_f1_score`. This is checked by presence, not
 * truthiness, so a logged value of 0 still counts. A run without the scoring
 * metric scores 0. On ties, the earliest run in the payload wins.
 *
 * Neither the payload array nor its run objects are mutated: the best run is
 * stored as a shallow copy carrying the flag. An empty batch returns a state
 * with an unchanged copy of `runMap`.
 *
 * @param state - current training slice state
 * @param action - the runs to merge
 * @returns a new state object holding a new `runMap`
 */
export const setRuns = (
	state: TrainingReducer,
	action: PayloadAction<{ runs: MLFlowRun[] }>
): TrainingReducer => {
	const multiLabelMetricKey = "test_optimal_thresholds_empty_f1_score";
	const { runs } = action.payload;
	const runMap: { [key: string]: MLFlowRun } = {
		...state.runMap,
	};

	// Value of the named metric on a run, or undefined if it was not logged.
	const metricValue = (run: MLFlowRun, key: string) =>
		run.data.metrics?.find((metric) => metric.key === key)?.value;

	const isMultiLabel = runs.some(
		(run) => metricValue(run, multiLabelMetricKey) !== undefined
	);
	const metricName = getF1ScoreMetricName(isMultiLabel);

	// One linear scan instead of sorting the payload array in place. The strict
	// `>` keeps the first run among equal scores.
	let bestRun: MLFlowRun | undefined;
	let bestScore = -Infinity;
	for (const run of runs) {
		const score = metricValue(run, metricName) ?? 0;
		if (score > bestScore) {
			bestRun = run;
			bestScore = score;
		}
	}

	for (const run of runs) {
		// Match on run_id, the same field used as the map key. Copy rather
		// than mutate, because payload objects may be shared or frozen.
		const isBest =
			bestRun !== undefined && run.info.run_id === bestRun.info.run_id;
		runMap[run.info.run_id] = isBest ? { ...run, isBestRun: true } : run;
	}

	return {
		...state,
		runMap,
	};
};

/**
 * Upserts Airflow DAG runs into `dagRunMap`, keyed by `dag_run_id`.
 * Existing entries not in the payload are kept. When the payload repeats an
 * id, the later occurrence wins.
 *
 * @param state - current training slice state
 * @param action - carries the DAG runs to merge
 * @returns a new state object holding a new `dagRunMap`
 */
export const setDagRuns = (
	state: TrainingReducer,
	action: PayloadAction<{ dagRuns: AirflowDagRun[] }>
): TrainingReducer => {
	// Index the incoming runs first, then layer them over the existing map.
	const incoming: { [key: string]: AirflowDagRun } = {};
	action.payload.dagRuns.forEach((run) => {
		incoming[run.dag_run_id] = run;
	});

	return {
		...state,
		dagRunMap: { ...state.dagRunMap, ...incoming },
	};
};

/**
 * Upserts MLflow model versions for one project into `modelVersionMap`.
 * The map is keyed first by project id and then by each version's `version`.
 *
 * Both the outer map and the project's inner map are copied before writing.
 * Previously only the outer map was copied, so the write went straight into
 * the prior state's nested object, which may be shared or frozen. An empty
 * `modelVersions` list leaves the project's entry as it was and does not
 * create one.
 *
 * @param state - current training slice state
 * @param action - `projectId` to file the versions under, and the versions to merge
 * @returns a new state object holding a new `modelVersionMap`
 */
export const setModelVersions = (
	state: TrainingReducer,
	action: PayloadAction<{
		projectId: string;
		modelVersions: MLFlowModelVersion[];
	}>
): TrainingReducer => {
	const { projectId, modelVersions } = action.payload;
	const modelVersionMap = {
		...state.modelVersionMap,
	};

	if (modelVersions.length > 0) {
		// Copy the project's inner map instead of writing through the shallow
		// outer copy into the previous state's object.
		const projectVersions = { ...modelVersionMap[projectId] };
		for (const modelVersion of modelVersions) {
			projectVersions[modelVersion.version] = modelVersion;
		}
		modelVersionMap[projectId] = projectVersions;
	}

	return {
		...state,
		modelVersionMap,
	};
};
