import { Fragment, useEffect, useMemo, useState } from 'react'

import { Datum, ResponsiveLine, Serie } from '@nivo/line'
import { ScaleLinear, Area, Line } from 'd3'
import { area, line, curveBasis } from 'd3-shape'
import { format } from 'date-fns'
import { Patient } from 'types/graphql'

import StackView from 'src/components/atoms/StackView'
import Typography from 'src/components/atoms/Typography/Typography'
import { toDecimal } from 'src/lib/formatters'

import { chartColors } from './ChartSection'

/**
 * Props for the GrowthChart component.
 *
 * Age-based series are supplied with x in months; for charts that are neither
 * infant nor weight-for-height, x values are converted to years before
 * plotting.
 */
export type GrowthChartProps = {
  // Data ------------------------------------------------------------------
  /** The charted patient; `givenName` labels the tooltip when corrected age is off. */
  patient: Patient
  /**
   * Every series to plot. Patient measurements use the ids 'patient' and
   * 'correctedPatient'; percentile curves use ids such as 'P5'.
   */
  chartData: Serie[]
  /** Also draw the 'correctedPatient' series and include it in the tooltip. */
  showCorrectedAge: boolean
  /** Infant data keeps its x values in months. */
  isInfantData: boolean
  /** Weight-for-height data keeps its x values unconverted. */
  isWeightForHeightData: boolean

  // Axes ------------------------------------------------------------------
  xAxisLabel: string
  yAxisLabel: string
  /** Appended in parentheses to the y axis legend. */
  yAxisUnit: string
  /** Lower x bound in plotted axis units; also used to filter the plotted data. */
  xMin: number | 'auto'
  /** Upper x bound in plotted axis units; also used to filter the plotted data. */
  xMax: number | 'auto'
  /** Lower y bound passed to the chart's y scale. */
  yMin?: number | 'auto'
  /** Upper y bound; 'auto' means the tallest data value plus some headroom. */
  yMax?: number | 'auto'

  // Overlays --------------------------------------------------------------
  /** Filled bands drawn between pairs of boundary lines. */
  areaLayers?: AreaLayer[]
  /** Circled labels placed on percentile curves. */
  percentileLabels?: PercentileLabel[]
  /** Spacing, in data-point indices, between consecutive percentile labels. */
  percentileLabelDistance?: number
  /** x position of the dashed "today" line, in months (divided by 12 unless isInfantData). */
  todayLineX?: number

  // Presentation ----------------------------------------------------------
  /** Passed to the chart as its series colour list. */
  lineColors?: string[]
  /** Draw a line through the patient's plotted points. */
  hasConnectedPoints: boolean
  /** CSS height of the chart container. */
  growthChartHeight: string

  // Callbacks -------------------------------------------------------------
  /** Invoked once, the first time the chart's final layer renders. */
  onGrowthChartFinishedRendering?: () => void
}

/** The x and y linear scales the chart hands to its custom layers. */
type Scales = Record<'xScale' | 'yScale', ScaleLinear<number, number>>

/**
 * Points describing a filled band: for each x, the lower (`y0`) and upper
 * (`y1`) edge values. Note this names the whole point list, not one point;
 * use `AreaDatum[number]` for a single point.
 */
type AreaDatum = Array<{
  x: number
  y0: number
  y1: number
}>

/**
 * Props for rendering one filled band. The resolved boundary series replace
 * the `bottom`/`top` boundary *names* of AreaLayer. Intersecting them
 * directly would make each of those properties `Datum[] & <string literal>`,
 * which no caller can supply. The chart scales are added on top.
 */
type AreaLayerProps = Omit<AreaLayer, 'bottom' | 'top'> & {
  /** Series tracing the lower edge of the band. */
  bottom: Datum[]
  /** Series tracing the upper edge; paired with `bottom` by array index. */
  top: Datum[]
} & Scales

/**
 * Names of the lines an area band may be bounded by.
 * 'bottomAxis' and 'topAxis' are horizontal lines at the lower/upper y bound.
 * 'bottomLine' and 'topLine' are the outermost percentile curves
 * (P2/P3 and P98/P97). The 'pNN' entries are the matching percentile curves.
 */
type AreaBoundary =
  | 'bottomAxis'
  | 'bottomLine'
  | 'p5'
  | 'p10'
  | 'p25'
  | 'p50'
  | 'p75'
  | 'p85'
  | 'p90'
  | 'p95'
  | 'topLine'
  | 'topAxis'

/** Declarative description of one filled band between two boundary lines. */
type AreaLayer = {
  /** Line forming the lower edge of the band. */
  bottom: AreaBoundary
  /** Line forming the upper edge of the band. */
  top: AreaBoundary
  /** Fill colour of the band. */
  color?: string
  /** Fill opacity; the renderer defaults it to 0.5. */
  fillOpacity?: number
}

/** A percentile-label description together with the chart scales used to place it. */
type PercentileLabelProps = Scales & PercentileLabel

/** Describes a circled label drawn on one percentile curve. */
type PercentileLabel = {
  /** Series id to label, e.g. 'P2'; the text after 'P' becomes the label. */
  dataKey: string
  /** Id to prefer when a series with it exists (CDC data uses P3/P97 for P2/P98). */
  alternateDataKey?: string
  /** Position of this label among its siblings; filled in by the label layer. */
  index?: number
  /** Total number of labels; filled in by the label layer. */
  numLabels?: number
  /** Circle fill colour (defaults to chartColors.yellowLight). */
  circleFill?: string
  /** Circle outline colour (defaults to chartColors.orange). */
  circleStroke?: string
  /** Label text colour (defaults to chartColors.orange). */
  labelFill?: string
}

/** Props for drawing a series of individual patient measurements. */
type PlottedPointsProps = Scales & {
  /** Points to draw, in plotting order. */
  data: Datum[]
  /** Fill and stroke colour for the markers (and the connecting line, if drawn). */
  color: string
  /** Marker shape; the renderer defaults it to 'circle'. */
  symbol?: 'square' | 'circle'
  /** `data-testid` placed on the wrapping `<g>`. */
  testId: string
}

/**
 * Fills the region between two series as a smoothed (basis-curve) SVG area.
 *
 * `bottom` and `top` are paired point-by-point by array index, so they are
 * expected to share the same x grid. Only their overlapping prefix is drawn.
 * Renders nothing when either boundary series is missing or empty, instead of
 * throwing or emitting a path full of NaN coordinates.
 */
const AreaBand = ({
  bottom,
  top,
  color,
  fillOpacity = 0.5,
  xScale,
  yScale,
}: AreaLayerProps) => {
  if (!bottom?.length || !top?.length) {
    return null
  }

  const pointCount = Math.min(bottom.length, top.length)
  const points: AreaDatum = bottom.slice(0, pointCount).map((d, i) => ({
    x: Number(d.x),
    y0: Number(d.y),
    y1: Number(top[i].y),
  }))

  const areaGenerator: Area<AreaDatum[number]> = area<AreaDatum[number]>()
    .x((p) => xScale(p.x))
    .y0((p) => yScale(p.y0))
    .y1((p) => yScale(p.y1))
    .curve(curveBasis)

  return (
    <path
      d={areaGenerator(points) ?? undefined}
      fill={color}
      fillOpacity={fillOpacity}
    />
  )
}

/**
 * Draws one marker (circle or square) per datum. When `connected` is true,
 * it also draws a straight polyline through the points in array order.
 */
const PlottedPoints = ({
  xScale,
  yScale,
  data,
  color,
  symbol = 'circle',
  testId,
  connected = false,
}: PlottedPointsProps & { connected?: boolean }) => {
  const lineGenerator: Line<Datum> = line<Datum>()
    .x((d) => xScale(Number(d.x)))
    .y((d) => yScale(Number(d.y)))

  return (
    <g data-testid={testId}>
      {connected && (
        <path
          d={lineGenerator(data) ?? undefined}
          fill="none"
          stroke={color}
          strokeWidth="3"
        />
      )}
      {data.map((d, i) => {
        const cx = xScale(Number(d.x))
        const cy = yScale(Number(d.y))
        return (
          <Fragment key={i}>
            {symbol === 'circle' && (
              <circle cx={cx} cy={cy} r={4} fill={color} stroke={color} />
            )}
            {symbol === 'square' && (
              <rect
                x={cx - 4}
                y={cy - 4}
                width="8"
                height="8"
                fill={color}
                stroke={color}
              />
            )}
          </Fragment>
        )
      })}
    </g>
  )
}

/**
 * One column of the hover tooltip: a legend swatch and label, followed by
 * whichever measurement fields are present on the hovered datum.
 */
const TooltipSlice = ({
  sliceData,
  color,
  label,
  symbol,
}: {
  sliceData: Datum
  color: string
  label?: string | null
  symbol: 'circle' | 'square'
}) => (
  <StackView justifyContent="center" className="w-30">
    <StackView
      direction="row"
      justifyContent="center"
      alignItems="center"
      space={50}
      className="pb-2"
    >
      <div
        className={`h-3 w-3 ${symbol === 'circle' ? 'rounded-lg' : ''}`}
        style={{
          backgroundColor: color,
        }}
      ></div>
      <Typography>{label}</Typography>
    </StackView>

    {sliceData.effectiveAt && (
      <Typography className="text-center" fontWeight="bold">
        {format(new Date(sliceData.effectiveAt), 'MM-dd-yyyy')}
      </Typography>
    )}
    {sliceData.value && (
      <Typography className="text-center" fontWeight="bold">
        {toDecimal(sliceData.value.value)} {sliceData.value.unit}
      </Typography>
    )}

    {sliceData.height && (
      <Typography className="text-center" fontWeight="bold">
        {toDecimal(sliceData.height.value)} {sliceData.height.unit}
      </Typography>
    )}

    {sliceData.imperial && (
      <Typography className="text-center" fontWeight="bold">
        {sliceData.imperial.display}
      </Typography>
    )}

    {sliceData.heightImperial && (
      <Typography className="text-center" fontWeight="bold">
        {sliceData.heightImperial.display}
      </Typography>
    )}

    {/* `!= null` so a 0th percentile shows "0%" rather than a bare "0". */}
    {sliceData.percentile != null && (
      <Typography className="text-center" fontWeight="bold">
        {toDecimal(sliceData.percentile)}%
      </Typography>
    )}
  </StackView>
)

/**
 * Renders nothing. It calls `onRendered` from an effect, which React runs
 * after the render containing this element has been committed. Calling it
 * during render instead would update parent state while nivo's Line is
 * rendering, which React warns against.
 */
const RenderCompleteNotifier = ({ onRendered }: { onRendered: () => void }) => {
  useEffect(() => {
    onRendered()
  }, [onRendered])
  return null
}

/**
 * Growth chart built on nivo's ResponsiveLine.
 *
 * Draws percentile curves from `chartData`, with optional filled bands
 * (`areaLayers`), percentile labels, a dashed "today" line, and the
 * patient's (and optionally corrected-age) measurements. For charts that are
 * neither infant nor weight-for-height, x values are converted from months
 * to years.
 */
const GrowthChart = ({
  patient,
  chartData,
  showCorrectedAge,
  xAxisLabel,
  yAxisLabel,
  xMin,
  xMax,
  yMin,
  yMax,
  yAxisUnit,
  lineColors,
  areaLayers,
  percentileLabels,
  percentileLabelDistance,
  todayLineX,
  isInfantData,
  isWeightForHeightData,
  hasConnectedPoints,
  growthChartHeight,
  onGrowthChartFinishedRendering,
}: GrowthChartProps) => {
  const [initiallyLoaded, setInitiallyLoaded] = useState<boolean>(false)
  const {
    transformedChartData,
    patientData,
    correctedPatientData,
    transformedTodayLineX,
    highestYValue,
  } = useMemo(() => {
    let series = chartData

    // If we are not looking at infant data or weight for height data, divide
    // months by 12 to get years for the x axis.
    if (!isInfantData && !isWeightForHeightData) {
      series = series.map((s) => ({
        id: s.id,
        data: s.data.map((d) => ({ ...d, x: Number(d.x) / 12 })),
      }))
    }

    // Drop points outside whichever x bounds are fixed. Each bound is applied
    // on its own, so one 'auto' bound doesn't disable clipping on the other.
    if (xMin !== 'auto' || xMax !== 'auto') {
      series = series.map((s) => ({
        id: s.id,
        data: s.data.filter((d) => {
          const x = Number(d.x)
          return (xMin === 'auto' || x >= xMin) && (xMax === 'auto' || x <= xMax)
        }),
      }))
    }

    const findData = (id: string) => series.find((s) => s.id === id)?.data

    // NOTE(review): the today line is converted to years whenever the data
    // isn't infant data, unlike the series above, which also stay unconverted
    // for weight-for-height data. This is kept as-is; confirm what callers
    // pass for weight-for-height charts.
    let todayX: number | undefined
    if (todayLineX != null) {
      todayX = isInfantData ? todayLineX : todayLineX / 12
    }

    // Tallest finite y across all series; non-numeric values are skipped so
    // they can't turn the axis maximum into NaN.
    let maxY = -Infinity
    for (const s of series) {
      for (const d of s.data) {
        const y = Number(d.y)
        if (Number.isFinite(y) && y > maxY) {
          maxY = y
        }
      }
    }

    return {
      transformedChartData: series,
      patientData: findData('patient'),
      correctedPatientData: findData('correctedPatient'),
      transformedTodayLineX: todayX,
      // Add headroom above the tallest series for an 'auto' y maximum.
      highestYValue: Number.isFinite(maxY) ? maxY + 5 : undefined,
    }
  }, [chartData, xMax, xMin, todayLineX, isInfantData, isWeightForHeightData])

  const PercentileAreaLayers = ({ xScale, yScale }) => {
    if (!areaLayers?.length) {
      return null
    }

    const findSeriesData = (...ids: string[]) =>
      transformedChartData.find((s) => ids.includes(String(s.id)))?.data

    // The y-scale domain is exactly the range nivo drew as the y axis. Bands
    // running "to the axis" therefore line up whether yMin/yMax were fixed,
    // 'auto', or omitted.
    const [axisBottomY, axisTopY] = yScale.domain()

    const bottomLine = findSeriesData('P2', 'P3')
    const topLine = findSeriesData('P98', 'P97')

    // Typed as a Record over every boundary name, so a name the AreaLayer type
    // allows can't be missing here (previously 'p50' was).
    const lineData: Record<AreaLayer['bottom'], Datum[] | undefined> = {
      bottomAxis: bottomLine?.map((d) => ({ x: d.x, y: axisBottomY })),
      bottomLine,
      p5: findSeriesData('P5'),
      p10: findSeriesData('P10'),
      p25: findSeriesData('P25'),
      p50: findSeriesData('P50'),
      p75: findSeriesData('P75'),
      p85: findSeriesData('P85'),
      p90: findSeriesData('P90'),
      p95: findSeriesData('P95'),
      topLine,
      topAxis: topLine?.map((d) => ({ x: d.x, y: axisTopY })),
    }

    return (
      <>
        {areaLayers.map((layer) => (
          <AreaBand
            key={`area-layer.top-${layer.top}.bottom.${layer.bottom}`}
            bottom={lineData[layer.bottom]}
            top={lineData[layer.top]}
            color={layer.color}
            fillOpacity={layer.fillOpacity}
            xScale={xScale}
            yScale={yScale}
          />
        ))}
      </>
    )
  }

  const PercentileLabelMarker = ({
    xScale,
    yScale,
    dataKey,
    alternateDataKey = null,
    index,
    numLabels,
    circleFill = chartColors.yellowLight,
    circleStroke = chartColors.orange,
    labelFill = chartColors.orange,
  }: PercentileLabelProps) => {
    // CDC data starts/ends with P3/P97 instead of P2/P98, so prefer the
    // alternate series when present. The label text and the labelled curve
    // both come from the same series.
    const series =
      transformedChartData.find((s) => s.id === alternateDataKey) ??
      transformedChartData.find((s) => s.id === dataKey)
    if (!series) {
      return null
    }

    const label = String(series.id).split('P')[1]
    const { data } = series

    // Labels are spaced percentileLabelDistance points apart, centred on the
    // middle of the curve. Rounding keeps odd counts/spacings on a real
    // index (e.g. 7.5 would otherwise silently hide the label).
    const middleIndex = Math.floor(data.length / 2)
    const labelStart = middleIndex - (numLabels / 2) * percentileLabelDistance
    const point = data[Math.round(labelStart + index * percentileLabelDistance)]
    if (!point) {
      return null
    }

    const cx = xScale(Number(point.x))
    const cy = yScale(Number(point.y))

    return (
      <g>
        <circle
          cx={cx}
          cy={cy}
          r={8}
          fill={circleFill}
          stroke={circleStroke}
          strokeWidth="2"
        />
        <text
          x={cx}
          y={cy}
          textAnchor="middle"
          fill={labelFill}
          dy="3px"
          fontSize="8px"
          fontWeight="bold"
        >
          {label}
        </text>
      </g>
    )
  }

  const PercentileLabelLayers = ({ xScale, yScale }) => {
    if (!percentileLabels?.length) {
      return null
    }

    return (
      <>
        {percentileLabels.map((label, index) => (
          <PercentileLabelMarker
            key={label.dataKey}
            index={index}
            numLabels={percentileLabels.length}
            xScale={xScale}
            yScale={yScale}
            dataKey={label.dataKey}
            alternateDataKey={label.alternateDataKey}
            circleFill={label.circleFill}
            circleStroke={label.circleStroke}
            labelFill={label.labelFill}
          />
        ))}
      </>
    )
  }

  const PatientLine = ({ xScale, yScale }) => {
    if (!patientData) {
      return null
    }
    return (
      <PlottedPoints
        xScale={xScale}
        yScale={yScale}
        data={patientData}
        color={chartColors.black}
        testId="patient-plotted-points"
        connected={hasConnectedPoints}
      />
    )
  }

  const CorrectedPatientLine = ({ xScale, yScale }) => {
    if (!showCorrectedAge || !correctedPatientData) {
      return null
    }
    return (
      <PlottedPoints
        xScale={xScale}
        yScale={yScale}
        data={correctedPatientData}
        color={chartColors.blue}
        symbol="square"
        testId="corrected-plotted-points"
        connected={hasConnectedPoints}
      />
    )
  }

  const TodayLine = ({ xScale, yScale }) => {
    // An explicit undefined/NaN check: 0 is a valid position (e.g. at birth)
    // and must not be treated as "no line".
    if (
      transformedTodayLineX === undefined ||
      !Number.isFinite(transformedTodayLineX)
    ) {
      return null
    }
    if (xMin !== 'auto' && transformedTodayLineX < xMin) {
      return null
    }
    if (xMax !== 'auto' && transformedTodayLineX > xMax) {
      return null
    }

    // Span the full drawn y axis, whatever yMin/yMax resolved to.
    const [axisBottomY, axisTopY] = yScale.domain()

    const lineGenerator: Line<Datum> = line<Datum>()
      .x((d) => xScale(d.x))
      .y((d) => yScale(d.y))

    const data = [
      { x: transformedTodayLineX, y: axisTopY },
      { x: transformedTodayLineX, y: axisBottomY },
    ]
    return (
      <g>
        <path
          d={lineGenerator(data) ?? undefined}
          fill="none"
          stroke={chartColors.black}
          style={{ strokeWidth: 1, strokeDasharray: '6,6' }}
          data-testid="today-line"
        />
        <text
          x={xScale(transformedTodayLineX)}
          y={yScale(axisTopY)}
          textAnchor="end"
          dx="-4px"
          dy="-4px"
          fontSize="10px"
          data-testid="today-line-text"
        >
          {format(new Date(), 'MMM d, yyyy')}
        </text>
      </g>
    )
  }

  // Fire the "finished rendering" callback once. This runs from an effect
  // (see RenderCompleteNotifier), not during nivo's render pass.
  const handleChartRendered = () => {
    if (initiallyLoaded) {
      return
    }
    onGrowthChartFinishedRendering?.()
    setInitiallyLoaded(true)
  }

  const FinalLayer = () => (
    <RenderCompleteNotifier onRendered={handleChartRendered} />
  )

  return (
    <div
      className="my-core-space-50 bg-base-color-bg-subtle"
      data-testid="growth-chart"
      style={{ height: growthChartHeight }}
    >
      <ResponsiveLine
        data={transformedChartData}
        margin={{ top: 15, right: 10, bottom: 50, left: 50 }}
        xScale={{ type: 'linear', min: xMin, max: xMax }}
        yScale={{
          type: 'linear',
          min: yMin,
          max: yMax === 'auto' ? highestYValue ?? 'auto' : yMax,
        }}
        axisBottom={{
          tickSize: 5,
          tickPadding: 5,
          tickRotation: 0,
          legend: xAxisLabel,
          legendOffset: 36,
          legendPosition: 'middle',
        }}
        axisLeft={{
          tickSize: 5,
          tickPadding: 5,
          tickRotation: 0,
          legend: `${yAxisLabel} (${yAxisUnit})`,
          legendOffset: -40,
          legendPosition: 'middle',
        }}
        curve="basis"
        lineWidth={2}
        colors={lineColors}
        enablePoints={false}
        enableSlices="x"
        layers={[
          'grid',
          'markers',
          'axes',
          'areas',
          PercentileAreaLayers,
          'lines',
          'points',
          PercentileLabelLayers,
          TodayLine,
          PatientLine,
          CorrectedPatientLine,
          'crosshair',
          'slices',
          'mesh',
          'legends',
          FinalLayer,
        ]}
        enableGridX={true}
        enableGridY={true}
        animate={false}
        sliceTooltip={({ slice }) => {
          const patientSlice = slice.points.find(
            (p) => p.serieId === 'patient'
          )?.data
          // Corrected-age values are only shown when that view is switched on.
          const correctedPatientSlice = showCorrectedAge
            ? slice.points.find((p) => p.serieId === 'correctedPatient')?.data
            : undefined

          if (!patientSlice && !correctedPatientSlice) {
            return null
          }

          return (
            <div className="rounded-md bg-white p-5 shadow-md">
              <StackView direction="row" space={100}>
                {patientSlice && (
                  <TooltipSlice
                    sliceData={patientSlice}
                    color={chartColors.black}
                    symbol="circle"
                    label={
                      showCorrectedAge ? 'Chronological' : patient.givenName
                    }
                  />
                )}
                {correctedPatientSlice && (
                  <TooltipSlice
                    sliceData={correctedPatientSlice}
                    color={chartColors.blue}
                    symbol="square"
                    label="Corrected"
                  />
                )}
              </StackView>
            </div>
          )
        }}
      />
    </div>
  )
}

export default GrowthChart
