import ArrowForwardIosIcon from '@mui/icons-material/ArrowForwardIos';
import EditIcon from '@mui/icons-material/Edit';
import { Stack, TextField, Typography } from "@mui/material";
import { useEffect, useState } from "react";
import { useForm } from 'react-hook-form';
import { useDispatch } from 'react-redux';
import { useNavigate } from 'react-router';
import Configure from '../Components/TrainingViewComponents/FineTuning/Configure';
import { CustomGPUConfig } from '../Components/TrainingViewComponents/FineTuning/CustomGPUConfig';
import { DataSources } from '../Components/TrainingViewComponents/FineTuning/DataSources';
import CustomButton from '../Components/UiComponents/CustomButton';
import Loader from "../Components/UiComponents/Loader";
import { FINE_TUNE_DATASET_OPTIONS } from '../Configs/ConfigureNewJobConstants';
import { CONFIGURE_OPTIONS } from '../Configs/JobConstants';
import { PAGE_ROUTES } from '../Configs/Routes';
import { setErrorMessage, setIsError } from '../DataStore/errorSlice';
import { useGetArtifactStoragePathsQuery } from '../Services/artifactStorageApi';
import { useGetGpuTypesMutation, useGetInstanceTypesMutation, useGetUserCloudsQuery } from '../Services/cloudProviderApi';
import { useCreateFineTuneMutation } from '../Services/fineTuneApi';
import { useGetModelsQuery } from '../Services/inferenceApi';
import { useGetUserOnPremNodesQuery } from '../Services/nodeApi';
import { useGetVirtualMountsQuery } from "../Services/virtualMountsApi";
import { color } from '../Styles/Color';

/**
 * Default values for the fine-tuning form, passed to `useForm({ defaultValues })`.
 *
 * How the page uses a few of these entries:
 * - `cloud_providers` is overwritten from the provider/region selection via `setValue`.
 * - `project_name` is replaced with `job_name` when the job is submitted.
 * - The LoRA parameters start as the string "0". When LoRA is enabled, validation treats
 *   "0" as "not set".
 * - On submit, top-level empty strings are sent as `null`. Nested `autotrain_params` values
 *   are sent as-is.
 */
const initialValue = {
  use_spot: false,
  wandb_key: "",
  model: "",
  data_path: "",
  job_name: "",
  user_dataset: "",
  hf_token: "",
  push_to_hub: false,
  username: "",
  repo_id: "",
  project_name: "",
  artifacts_storage: "",
  cloud_providers: [],
  // Training hyper-parameters, sent nested under `autotrain_params`.
  autotrain_params: {
    use_peft: "lora",
    quantization: null,
    mixed_precision: null,
    disable_gradient_checkpointing: true,
    use_flash_attention_2: true,
    lora_r: "0",
    lora_alpha: "0",
    lora_dropout: "0",
    lr: 3e-5,
    batch_size: 1,
    epochs: 1,
    train_subset: "",
    text_column: "",
    gradient_accumulation: null,
    model_max_length: null,
    block_size: null,
    torch_dtype: "auto",
    seed: null,
  },
  gpu_type: "A10G",
  gpu_count: 0,
}

export const FineTuningLlamaMistral = () => {

  const dispatch = useDispatch()

  const { register, setError, formState: { errors }, watch, handleSubmit, setValue, reset, clearErrors } =
    useForm({ defaultValues: initialValue })

  const nav = useNavigate()
  const [typeSelected, setTypeSelected] = useState(FINE_TUNE_DATASET_OPTIONS[0])
  const [isLoRA, setIsLoRA] = useState(false)
  const [isQuantization, setIsQuantization] = useState(false)
  const [isMixedPrecision, setIsMixedPrecision] = useState(false)
  const [gpuNodes, setGpuNodes] = useState([])
  // const [gputype, setGputype] = useState("")
  const [allowOtherGPU, setAllowOtherGPU] = useState(false)
  const [cloudProvider, setCloudProvider] = useState([])
  const [region, setRegion] = useState([])
  const [regionOptions, setRegionOptions] = useState([])
  const [cloudBurst, setCloudBurst] = useState(false)
  const [gpuNodeType, setGPUNodeType] = useState([])
  // const [numberOfGPUs, setNumberOfGPUs] = useState(-1)
  const [instanceType, setInstancetype] = useState("")
  const [gpuPerNode, setGPUPerNode] = useState(-1)
  const [selectedConfig, setSelectedConfig] = useState(CONFIGURE_OPTIONS[0])


  const [submit, { isLoading: isSubmitting, isSuccess }] = useCreateFineTuneMutation()
  const { data: artifactStoragePaths, isLoading } = useGetArtifactStoragePathsQuery()
  const { data: userVMs, isLoading: isFetchingVMs } = useGetVirtualMountsQuery()
  const { data: gpuOptions, isLoading: isFechingNodes } = useGetUserOnPremNodesQuery()
  const { data: cloudProviderOptions, isLoading: isFetchingClouds, isSuccess: isCloudsFetched } = useGetUserCloudsQuery()
  const [getGPUTypes, { data: gpuTypeOptions, isLoading: isFetchingGPUs }] = useGetGpuTypesMutation()
  const [getInstanceTypes, { data: instanceTypeOptions, isLoading: isFetchingInstances }] = useGetInstanceTypesMutation()

  const { data: models, isLoading: isFetchingModels } = useGetModelsQuery()


  useEffect(() => {
    if (isCloudsFetched) {
      setCloudProvider(
        [...cloudProviderOptions
          ?.filter(d => ["AWS", "AZURE", "GCP"].includes(d.cloud_provider))
          ?.map(cp => cp.cloud_provider),
          "SCALEGENAI"
        ] || []
      )
      setRegion(
        [...cloudProviderOptions
          ?.filter(d => ["AWS", "AZURE", "GCP"].includes(d.cloud_provider))
          ?.reduce(
            (accumulator, cp) => {
              return [...accumulator, ...cp.regions.map(region => `${cp.cloud_provider}:${region}`)]
            },
            [],
          ),
          "SCALEGENAI:US", "SCALEGENAI:ASIA", "SCALEGENAI:EU", "SCALEGENAI:CANADA"
        ]
      )
      setRegionOptions(
        [...cloudProviderOptions
          ?.filter(d => ["AWS", "AZURE", "GCP"].includes(d.cloud_provider))
          ?.reduce(
            (accumulator, cp) => {
              return [...accumulator, ...cp.regions.map(region => `${cp.cloud_provider}:${region}`)]
            },
            [],
          ),
          "SCALEGENAI:US", "SCALEGENAI:ASIA", "SCALEGENAI:EU", "SCALEGENAI:CANADA"
        ]
      )
    }
  }, [cloudProviderOptions, isCloudsFetched, setValue])

  useEffect(() => {
    const timer = setTimeout(() => {
      if (cloudProvider.length > 0 && region.length > 0) {
        getGPUTypes({
          cloudProviders: cloudProvider,
          cloudRegions: region
        })
        getInstanceTypes({
          cloudProviders: cloudProvider,
          cloudRegions: region
        })
      }
    }, 3000)

    setValue('cloud_providers',
      cloudProvider.map(cp => {
        return {
          name: cp,
          regions: region?.filter(r => r.includes(cp)).map(r => r.split(":")[1])
        }
      })
    )

    return () => {
      clearTimeout(timer)
    };
  }, [cloudProvider, getGPUTypes, getInstanceTypes, region, setValue])

  const handleValidation = () => {
    clearErrors()

    if (watch('model').length === 0) {
      setError('model', { type: 'custom', message: 'This field is required' })
    }
    if (watch('data_path').length === 0) {
      setError('dataset', { type: 'custom', message: 'This field is required' })
    }

    if (watch('push_to_hub') && watch('username').length === 0) {
      setError('push_to_hub.username', { type: 'custom', message: 'This field is required' })
    }

    if (watch('push_to_hub') && watch('repo_id').length === 0) {
      setError('push_to_hub.repo_id', { type: 'custom', message: 'This field is required' })
    }

    if (isLoRA && (
      watch('autotrain_params.lora_alpha') === "0" ||
      watch('autotrain_params.lora_dropout') === "0" ||
      watch('autotrain_params.lora_r') === "0"
    )) {
      setError('autotrain_params', { type: 'custom', message: 'This field is required' })
      dispatch(setIsError(true))
      dispatch(setErrorMessage("Disable lora toggle , If you don't want lora"))
    }

    handleSubmit(handleStart)()
  }

  const handleStart = (data) => {
    const config = {
      ...data,
      "project_name": data.job_name
    }

    Object.keys(config).forEach(function (key, index) {
      if (this[key] === "") this[key] = null;
    }, config)

    submit(config)
  }

  if (isSubmitting || isLoading || isFetchingVMs || isFechingNodes || isFetchingClouds || isFetchingModels) {
    return <Stack height="80vh"><Loader /></Stack>
  }

  if (isSuccess) {
    reset({ data: initialValue })
    nav(PAGE_ROUTES.training)
  }

  return (
    <Stack gap={2} overflow="auto"
      sx={{
        '&::-webkit-scrollbar': {
          display: 'none',
        }
      }}
      py={1}
    >
      <Typography fontFamily="IBMPlexSansSemiBold" fontSize="25px" pb={2}>
        Fine-Tune (llama and mistral)
      </Typography>
      <Stack direction="row" gap={2} alignItems="center">
        <Typography fontSize="15px">Job Name</Typography>
        <TextField
          size='small'
          variant="standard"
          placeholder='Enter job name'
          sx={{
            fontFamily: "IBMPlexSansSemiBold",
          }}
          InputProps={{
            style: {
              fontFamily: "IBMPlexSansSemiBold", fontSize: "15px",
              color: color.primary
            },
          }}
          {
          ...register('job_name', { required: { value: true, message: "This field is required" } })
          }
          error={errors["job_name"] ? true : false}
          helperText={errors["job_name"]?.message}
        />
        <EditIcon sx={{ color: color.primary, fontSize: "16px" }} />
      </Stack>
      <DataSources
        artifactStoragePaths={artifactStoragePaths}
        typeSelected={typeSelected} setTypeSelected={setTypeSelected} vmOptions={userVMs}
        modelList={
          models
            ?.filter(model => model.type === "llm")
            ?.filter(model => model.model.includes("llama") || model.model.includes("mistral"))
            ?.map(model => model.model)
        }
        register={register} errors={errors} setValue={setValue} watch={watch}
      />
      <CustomGPUConfig
        gpuOptions={gpuOptions}
        gpuNodes={gpuNodes}
        setGpuNodes={setGpuNodes}
        // gputype={gputype} setGputype={setGputype} 
        gpuTypeOptions={gpuTypeOptions}
        allowOtherGPU={allowOtherGPU} setAllowOtherGPU={setAllowOtherGPU}
        cloudProvider={cloudProvider} setCloudProvider={setCloudProvider}
        region={region} setRegion={setRegion}
        cloudProviderOptions={
          [...cloudProviderOptions, { cloud_provider: "SCALEGENAI", regions: ["US", "ASIA", "EU", "CANADA"] }]
        }
        regionOptions={regionOptions} cloudBurst={cloudBurst} setCloudBurst={setCloudBurst}
        gpuNodeType={gpuNodeType} setGPUNodeType={setGPUNodeType} instanceType={instanceType}
        setInstancetype={setInstancetype}
        // numberOfGPUs={numberOfGPUs} setNumberOfGPUs={setNumberOfGPUs}
        instanceTypeOptions={instanceTypeOptions} gpuPerNode={gpuPerNode} setGPUPerNode={setGPUPerNode}
        selectedConfig={selectedConfig} setSelectedConfig={setSelectedConfig}
        isFetchingGPUs={isFetchingGPUs} isFetchingInstances={isFetchingInstances}
        watch={watch} setValue={setValue}
      />

      <Configure
        selectedConfig={selectedConfig}
        cloudBurst={cloudBurst}
        register={register} errors={errors} setValue={setValue} watch={watch}
        isLoRA={isLoRA} setIsLoRA={setIsLoRA}
        isQuantization={isQuantization} setIsQuantization={setIsQuantization}
        isMixedPrecision={isMixedPrecision} setIsMixedPrecision={setIsMixedPrecision}
      />
      <Stack alignItems="end">
        {/* <Stack direction="row" gap={2} width="70%"> */}
        {/* <SecondaryButton onClick={() => { }}>
            <Typography color={color.primary} fontSize="15px">Save as Template</Typography>
          </SecondaryButton>
          <SecondaryButton onClick={() => { }}>
            <Typography color={color.primary} fontSize="15px">Download as YAML Config</Typography>
          </SecondaryButton> */}
        <CustomButton width="30%" onClick={handleValidation}>
          <Stack fontSize="15px" direction="row" alignItems="center">
            Start
            <ArrowForwardIosIcon sx={{ fontSize: "15px" }} />
          </Stack>
        </CustomButton>
      </Stack>
      {/* </Stack > */}
    </Stack >
  )
}

