import EditIcon from '@mui/icons-material/Edit'
import { Stack, TextField, Typography } from '@mui/material'
import { useEffect } from 'react'
import { useForm } from 'react-hook-form'
import { useDispatch } from 'react-redux'
import { useLocation, useNavigate } from 'react-router'
import Loader from '../Components/UiComponents/Loader'
import { COMPLETIONS_MODEL_PROVIDERS } from '../Configs/JobConstants'
import { PAGE_ROUTES } from '../Configs/Routes'
import { setErrorMessage, setIsError } from '../DataStore/errorSlice'
import { CustomFinetuneForm } from '../NewComponents/FineTuning/CustomFinetuneForm'
import { FinetuneTemplateForm } from '../NewComponents/FineTuning/FinetuneTemplateForm'
import { useCreateFineTuneMutation } from '../Services/fineTuneApi'
import { useGetFineTuneTemplatesQuery } from '../Services/templatesApi'
import { color } from '../Styles/Color'
import { getTrimmedData } from '../Utils/trimmer'

// Default form values for a new fine-tuning job; passed as useForm's
// `defaultValues` in LaunchFineTuning below. When a matching fine-tune
// template is fetched, the form is reset to that template's `config` instead.
// Top-level keys describe the job and the compute it runs on; the nested
// `autotrain_params` object holds the training hyper-parameters (presumably
// forwarded to an AutoTrain-style trainer — confirm against the backend).
const initialValue = {
  "ft_type": "CLM",
  "use_spot": true,
  "model": "",
  "base_model": "",
  "data_path": "",
  "job_name": "",
  "user_dataset": "",
  "artifacts_storage": "",
  "cloud_providers": [],
  "use_recipes": false,
  // Training parameters. `job_name` here is kept in sync with the top-level
  // `job_name` by the job-name input's onChange handler in LaunchFineTuning.
  "autotrain_params": {
    "model": "",
    "job_name": "",
    "project_name": "ScaleGen Project",
    "data_path": "",
    "push_to_hub": false,
    "repo_id": "",
    "username": "",
    "hf_token": "",
    "wandb_key": "",
    "comet_ml_key": "",
    "train_split": "train",
    "train_subset": "",
    "valid_split": "",
    "valid_subset": "",
    "add_eos_token": true,
    // NOTE(review): -1 looks like an "auto / not set" sentinel — confirm with the backend.
    "block_size": -1,
    "model_max_length": 0,
    "padding": "left",
    "trainer": "default",
    "use_flash_attention_2": false,
    "log": "none",
    "disable_gradient_checkpointing": false,
    // NOTE(review): -1 presumably lets the trainer choose — confirm.
    "logging_steps": -1,
    "eval_strategy": "epoch",
    "save_total_limit": 1,
    "save_strategy": "steps",
    "auto_find_batch_size": false,
    "mixed_precision": "fp16",
    "lr": 0.00003,
    "epochs": 1,
    "batch_size": 2,
    "warmup_ratio": 0.1,
    "gradient_accumulation_steps": 1,
    "optimizer": "adamw_torch",
    "lr_scheduler_type": "linear",
    "weight_decay": 0,
    "max_grad_norm": 1,
    "seed": 42,
    "save_steps": 20,
    "eval_steps": 0,
    "load_best_model_at_end": true,
    "resume_from_checkpoint": "",
    "user_checkpoint_dir": "",
    "neftune_noise_alpha": 0,
    "use_deepspeed": "",
    "apply_chat_template": "zephyr",
    "torch_dtype": "auto",
    "use_torch_compile": true,
    "quantization": "nf4",
    "double_quantization": false,
    // LoRA settings below are validated as non-zero in LaunchFineTuning when use_peft === "lora".
    "use_peft": "lora",
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "init_lora_weights": "gaussian",
    "use_rslora": false,
    "adalora_init_r": 12,
    "adalora_target_r": 8,
    "llama_adapter_len": 128,
    "llama_adapter_layers": 8,
    "target_modules": "",
    "merge_adapter": false,
    "model_ref": "",
    "dpo_beta": 0.1,
    "max_prompt_length": 128,
    "max_completion_length": 0,
    "prompt_text_column": "",
    "text_column": "text",
    "rejected_text_column": "",
    "use_unsloth": false
  },
  "gpu_type": "A100_PCIE",
  "gpu_count": 1
}

export const LaunchFineTuning = () => {

  const model_type = useLocation().state.model || COMPLETIONS_MODEL_PROVIDERS[3]

  const nav = useNavigate()
  const dispatch = useDispatch()

  const [submit, { isLoading: isSubmitting, isSuccess }] = useCreateFineTuneMutation()

  const { data: models, isLoading: isFetchingModels, isSuccess: isModelsFetched } = useGetFineTuneTemplatesQuery(
    null,
    {
      skip: model_type === COMPLETIONS_MODEL_PROVIDERS[2].name
    }
  )

  const { formState: { errors }, watch, setValue, setError, clearErrors, reset, handleSubmit } =
    useForm({
      defaultValues: initialValue
    })

  const handleBack = () => {
    nav(-1)
  }

  const handleValidation = () => {
    clearErrors()

    if (watch('job_name').length === 0) {
      dispatch(setIsError(true))
      dispatch(setErrorMessage("Job Name is required"))
      setError('job_name', { type: 'custom', message: 'Job Name is required' })
      return
    }

    if (watch('data_path').length === 0) {
      setError('data_path', { type: 'custom', message: 'This field is required' })
      return
    }

    if (watch('autotrain_params.push_to_hub') && watch('autotrain_params.repo_id').length === 0) {
      setError('push_to_hub.repo_id', { type: 'custom', message: 'This field is required' })
      return
    }

    if (watch("autotrain_params.use_peft") === "lora" && (
      watch('autotrain_params.lora_alpha') === "0" ||
      watch('autotrain_params.lora_dropout') === "0" ||
      watch('autotrain_params.lora_r') === "0" ||
      watch('autotrain_params.lora_alpha') === null ||
      watch('autotrain_params.lora_dropout') === null ||
      watch('autotrain_params.lora_r') === null
    )) {
      dispatch(setIsError(true))
      dispatch(setErrorMessage("Disable lora toggle , If you don't want lora"))
      return
    }

    const config = getTrimmedData(watch())
    handleSubmit(submit(config))
  }

  useEffect(() => {
    if (isModelsFetched && models?.length > 0) {
      reset(models?.filter(m => m?.config?.model.includes(model_type))[0]?.config)
    }
  }, [isModelsFetched, model_type, models, reset])

  if (isFetchingModels || isSubmitting) {
    return <Stack height="70vh">
      <Loader />
    </Stack>
  }

  if (isSuccess) {
    nav(PAGE_ROUTES.fineTuning)
  }

  return (
    <Stack spacing={5}>
      <Stack direction="row" gap={2} alignItems={errors?.job_name ? "start" : "end"}>
        <Typography variant="h1">{model_type}</Typography>
        <TextField
          size='small'
          variant="standard"
          placeholder='Enter fine-tuning job name'
          sx={{
            "& .MuiInput-root:before": {
              border: 0
            },
            "& .MuiInput-root:after": {
              border: 0
            },
            "& .MuiInput-input:hover": {
              border: 0
            }
          }}
          InputProps={{
            style: {
              fontSize: "14px",
              borderRadius: "8px",
              color: color.primary,
            },
          }}
          value={watch('job_name')}
          onChange={e => {
            setValue('job_name', e.target.value)
            setValue("autotrain_params.job_name", e.target.value)
          }}
          // {...register("job_name", { required: { value: true, message: "This field is required" } })}
          error={errors?.job_name ? true : false}
          helperText={errors?.job_name?.message}
        />
        <EditIcon sx={{ color: color.primary, fontSize: "16px" }} />
      </Stack>
      {
        (model_type === COMPLETIONS_MODEL_PROVIDERS[2].name) ?
          <CustomFinetuneForm
            watch={watch} setValue={setValue} errors={errors}
            setError={setError} clearErrors={clearErrors}
            handleSubmit={handleValidation} reset={reset}
          /> :
          <FinetuneTemplateForm
            watch={watch} setValue={setValue}
            isBaseModel={true} handleBack={handleBack} backTitle="Change Model"
            modelList={
              models?.filter(m => m.config.model.includes(model_type))
            }
            handleSubmit={handleValidation}
            reset={reset}
            errors={errors}
          />
      }
    </Stack>
  )
}
