import {
  ExperimentLabel,
  EXPERIMENT_TASK,
  Model,
  PreTrainedExperimentConfig,
} from 'types';
import { twoDecimalsRound } from 'utils';

/** Argument shape accepted by `getMetricNameLabels`. */
type MetricNameLabels = {
  // Task identifier; used as the lookup key into `accuracyTextMap`.
  task: number | string;
};

/**
 * Builds the `accuracy` and `inference` metric descriptors shown for an experiment.
 *
 * When `config.pre_trained_score_metric` is truthy, label, value and tooltip of both
 * metrics come straight from the `pre_trained_*` fields of `config`. Otherwise the
 * values are read from `model.result` and passed through `twoDecimalsRound`, with
 * fixed labels and tooltips.
 *
 * @param model - Model whose `result` supplies the values when `config` carries no
 *   pre-trained score metric.
 * @param config - Experiment configuration; its `pre_trained_*` fields take precedence.
 * @returns An object with `accuracy` and `inference` entries, each holding
 *   `label`, `number` and `tooltip`.
 */
export const getMetricsExperimentConfig = (
  model: Model,
  config: PreTrainedExperimentConfig
) => {
  if (config.pre_trained_score_metric) {
    return {
      accuracy: {
        label: config.pre_trained_score_metric.toUpperCase(),
        number: config.pre_trained_score_value,
        tooltip: config.pre_trained_score_metric_desc,
      },
      inference: {
        label: 'AVG INFERENCE TIME',
        number: config.pre_trained_inference_time,
        tooltip: config.pre_trained_inference_time_desc,
      },
    };
  }

  // No pre-trained metric: fall back to the values stored on the model's result.
  // The `&&` guards mean a missing result, or a falsy value such as `undefined`
  // or `0`, is passed through as-is without calling `twoDecimalsRound`.
  return {
    accuracy: {
      label: 'ACCURACY SCORE',
      number:
        model.result?.accuracy && twoDecimalsRound(model.result.accuracy),
      tooltip:
        'This is the ratio between the number of correctly predicted examples and the total number of examples in the validation set.',
    },
    inference: {
      label: 'AVG INFERENCE TIME',
      number:
        model.result?.avg_inference_time_in_sec &&
        twoDecimalsRound(model.result.avg_inference_time_in_sec),
      tooltip:
        'This is an approximation of the time required to use the model with a new input. This value results from measuring the average inference time by taking some random examples from the validation set.',
    },
  };
};

// Maps an experiment task to the label used for its accuracy metric.
// Only the image object detection task has an entry; any other key yields `undefined`.
const accuracyTextMap: ExperimentLabel = {
  [EXPERIMENT_TASK.IMAGE_OBJECT_DETECTION]: 'mAP SCORE',
};

/**
 * Looks up the accuracy label registered for a task in `accuracyTextMap`.
 *
 * @param labels - Object carrying the `task` to look up.
 * @returns An object whose `accuracy` field is the value stored in `accuracyTextMap`
 *   under that task (whatever the map yields when the task has no entry).
 */
export const getMetricNameLabels = ({ task }: MetricNameLabels) => {
  const accuracy = accuracyTextMap[task];
  return { accuracy };
};

/**
 * Initial configuration values exported for image experiments (per the constant's
 * name; its consumers are not visible in this file).
 *
 * All `daug_*` and `algo_*` flags start enabled. Of the `cnn_*` flags, `cnn_small`
 * and `cnn_inceptionv3` start disabled and the rest enabled. `metric_to_optimize`
 * starts as `'f1'`.
 */
export const imageInitialConfig = {
  daug_rotation: true,
  daug_brightness: true,
  daug_zoom: true,
  daug_horizontal_flip: true,
  daug_vertical_flip: true,
  algo_logistic_regression: true,
  algo_linear_sgd: true,
  algo_ffnn: true,
  algo_svc_rbf: true,
  cnn_small: false,
  cnn_xception: true,
  cnn_inceptionv3: false,
  cnn_inception_resnetv2: true,
  cnn_mobilenetv2: true,
  metric_to_optimize: 'f1',
};

/**
 * Initial configuration values exported for text experiments (per the constant's
 * name; its consumers are not visible in this file).
 *
 * All `algo_*` flags start enabled. Among the `vectorization_*` flags only
 * `vectorization_distilbert` starts enabled. `lid_model` starts disabled and
 * `metric_to_optimize` starts as `'auto'`.
 */
export const textInitialConfig = {
  only_ascii: true,
  do_lower: true,
  // NOTE(review): the next two values are the strings 'false' / 'true', not booleans,
  // unlike every other flag in this object. Presumably the consumer expects strings
  // here — confirm before changing, since the string 'false' is truthy.
  do_stopwords_removal: 'false',
  do_lemmatization: 'true',
  algo_linear_sgd: true,
  algo_ffnn: true,
  algo_svc_rbf: true,
  algo_logistic_regression: true,
  vectorization_bow: false,
  vectorization_tfidf: false,
  vectorization_fasttext: false,
  vectorization_distilbert: true,
  vectorization_bert: false,
  vectorization_xlmroberta: false,
  lid_model: false,
  metric_to_optimize: 'auto',
};

/**
 * Initial configuration values exported for audio experiments (per the constant's
 * name; its consumers are not visible in this file).
 *
 * `audio_type` starts as `'env'`. `algo_mla`, both `vectorization_*` flags, and all
 * `daug_*` and `algo_*` flags start enabled. Among the `cnn_*` flags only
 * `cnn_mobilenetv2` starts enabled. `metric_to_optimize` starts as `'auto'`.
 *
 * NOTE(review): this object repeats the `daug_*` and `cnn_*` keys that also appear in
 * `imageInitialConfig` — confirm they are intentionally part of the audio defaults.
 */
export const audioInitialConfig = {
  audio_type: 'env',
  algo_mla: true,
  vectorization_l3net: true,
  vectorization_vggish: true,
  daug_rotation: true,
  daug_brightness: true,
  daug_zoom: true,
  daug_horizontal_flip: true,
  daug_vertical_flip: true,
  algo_logistic_regression: true,
  algo_linear_sgd: true,
  algo_ffnn: true,
  algo_svc_rbf: true,
  cnn_small: false,
  cnn_xception: false,
  cnn_inceptionv3: false,
  cnn_inception_resnetv2: false,
  cnn_mobilenetv2: true,
  metric_to_optimize: 'auto',
};
