import { TrainingOutput } from "modules/types/training";
import { MLFlowRun } from "modules/types/training/ml-flow";
import { UserData } from "modules/types/user";

/**
 * Produces a display name for a user.
 *
 * Preference order:
 *  1. the profile first name, normalised to "Capitalised" form
 *     (first character upper-cased, the rest lower-cased, e.g. "jOHN" -> "John");
 *  2. the e-mail address, if truthy;
 *  3. the username (which may itself be undefined).
 *
 * @param user - The user record; tolerated as null/undefined.
 * @returns The chosen display string, or `undefined` when nothing is available.
 */
export const getUserFirstName = (user: UserData) => {
	const firstName = user?.profile?.firstName;

	// An empty or missing first name falls through to the other identifiers.
	if (!firstName) {
		return user?.email ? user.email : user?.username;
	}

	const head = firstName.charAt(0);
	const tail = firstName.slice(1);
	return head.toUpperCase() + tail.toLowerCase();
};

/**
 * Looks up a metric on an MLflow run by key and formats it for display.
 *
 * @param run - The run whose `data.metrics` list is searched; may be missing.
 * @param label - The metric key to look up.
 * @param formatType - Pass `"percent"` to scale a numeric value by 100 and
 *   append `" %"`; any other value (or omitting it) formats the raw number.
 * @returns For numeric values, the number rendered with 2 decimal places
 *   (optionally as a percentage). For truthy non-numeric values, the raw value
 *   unchanged. Otherwise (metric absent or falsy) the placeholder `"-"`.
 */
export const getRunMetric = (
	run: MLFlowRun,
	label: string,
	formatType?: string
) => {
	const metric = run?.data?.metrics?.find((m) => m.key === label);
	const value = metric?.value;

	// Numeric values are always rendered with exactly 2 decimal places.
	if (typeof value === "number") {
		return formatType === "percent"
			? `${(value * 100).toFixed(2)} %`
			: value.toFixed(2);
	}

	// Non-numeric values are passed through as-is when truthy.
	if (value) return value;
	return "-";
};

/**
 * Reports whether a training should be treated as multi-label.
 *
 * A literal `true` is accepted as a shortcut and `false` short-circuits to
 * "not multi-label". For a training output, the serialized training config
 * (`trainingJob.trainingDump`) is searched for the exact substring
 * `"multiLabel":true`. A missing output or missing dump means "not multi-label".
 */
const isMultiLabelTraining = (
	trainingOutput: TrainingOutput | boolean
): boolean => {
	if (typeof trainingOutput === "boolean") return trainingOutput;
	return !!trainingOutput?.trainingJob?.trainingDump?.includes(
		'"multiLabel":true'
	);
};

/**
 * Returns the MLflow metric key holding the weighted-average precision.
 * Multi-label trainings use the optimal-threshold variant of the metric.
 *
 * @param trainingOutput - The training output, or a boolean "is multi-label" flag.
 */
export const getPrecisionMetricName = (
	trainingOutput: TrainingOutput | boolean
) =>
	isMultiLabelTraining(trainingOutput)
		? "test_optimal_thresholds_weighted_avg_precision"
		: "test_weighted_avg_precision";

/**
 * Returns the MLflow metric key holding the weighted-average F1 score.
 * Multi-label trainings use the optimal-threshold variant of the metric.
 *
 * @param trainingOutput - The training output, or a boolean "is multi-label" flag.
 */
export const getF1ScoreMetricName = (
	trainingOutput: TrainingOutput | boolean
) =>
	isMultiLabelTraining(trainingOutput)
		? "test_optimal_thresholds_weighted_avg_f1_score"
		: "test_weighted_avg_f1_score";

/**
 * Returns the MLflow metric key holding the weighted-average recall.
 * Multi-label trainings use the optimal-threshold variant of the metric.
 *
 * @param trainingOutput - The training output, or a boolean "is multi-label" flag.
 */
export const getRecallMetricName = (
	trainingOutput: TrainingOutput | boolean
) =>
	isMultiLabelTraining(trainingOutput)
		? "test_optimal_thresholds_weighted_avg_recall"
		: "test_weighted_avg_recall";

/**
 * Extracts the headline sklearn metrics for one evaluation scope.
 *
 * Multi-label trainings (detected by the exact substring `"multiLabel":true`
 * in `trainingJob.trainingDump`) report the `"macro avg"` row. All others report
 * the `"weighted avg"` row. The scope's top-level `accuracy` is merged in and
 * overrides any `accuracy` field in the row. It defaults to 0 when missing or
 * otherwise falsy.
 *
 * @param trainingOutput - Training output whose `evaluation.by_sklearn` holds the
 *   per-scope reports; may be missing, in which case only `{ accuracy: 0 }` is returned.
 * @param runType - Which scope to read. Any value outside the declared union
 *   falls back to `"global"`.
 * @returns A shallow copy of the selected average row plus an `accuracy` field.
 */
export const getSklearnMetrics = (
	trainingOutput: TrainingOutput,
	runType: "pattern" | "model" | "global" = "global"
) => {
	const isMultiLabel = !!trainingOutput?.trainingJob?.trainingDump?.includes(
		'"multiLabel":true'
	);
	const averageKey = isMultiLabel ? "macro avg" : "weighted avg";

	// Guard against out-of-union values from untyped callers (e.g. null).
	const scope =
		runType === "pattern" || runType === "model" ? runType : "global";
	const report = trainingOutput?.evaluation?.by_sklearn?.[scope];

	return {
		...report?.[averageKey],
		accuracy: report?.accuracy || 0,
	};
};
