import { useEffect, useMemo, useRef } from "react";
import * as d3 from "d3";

/** Space (in px) reserved around the plotting area for axes and tick labels. */
const MARGIN = {
  top: 30,
  right: 30,
  bottom: 50,
  left: 50,
};

/**
 * Additional space (in px) added to the bottom and right margins when the
 * x-axis tick labels are rotated.
 */
const ROTATED_X_AXIS_EXTRA_MARGIN = {
  bottom: 25,
  right: 25,
};

/**
 * One row of chart input (one bar). Deliberately untyped: the component reads
 * the bar label from `x` and one value per name listed in `allSubgroups`.
 */
type Group = any;

/** Props accepted by the stacked bar chart component. */
interface StackedBarplotProps {
  /** Width given to the outer `<svg>` element. */
  width: number;
  /** Height given to the outer `<svg>` element. */
  height: number;
  /** Rows to draw, one bar per row. */
  data: Group[];
  /** Keys of the stacked segments; also the domain of the color scale. */
  allSubgroups: string[];
  /** When truthy, x-axis tick labels are rotated and extra margin is reserved. */
  rotateXAxis?: boolean;
}

/**
 * Stacked bar chart rendered as an SVG.
 *
 * The bars are rendered by React as `<rect>` elements; the two axes are drawn
 * imperatively by d3 into a dedicated `<g>` from an effect.
 *
 * @param width - Width of the outer `<svg>`.
 * @param height - Height of the outer `<svg>`.
 * @param data - One row per bar. The bar label is read from `x`, the segment
 *   values from the keys listed in `allSubgroups`.
 * @param allSubgroups - Keys of the stacked segments, bottom to top.
 * @param rotateXAxis - When truthy, x tick labels are rotated by 25 degrees and
 *   extra bottom/right margin is reserved for them.
 */
export const StackedBarplot = ({
  width,
  height,
  data,
  allSubgroups,
  rotateXAxis,
}: StackedBarplotProps) => {
  // <g> that d3 fills with the axes; React never renders children into it.
  const axesRef = useRef<SVGGElement>(null);

  // Bounds = area inside the graph axes = outer size minus the margins.
  // Clamped at 0 so a container smaller than the margins cannot produce
  // negative ranges and negative rect sizes.
  const extraRight = rotateXAxis ? ROTATED_X_AXIS_EXTRA_MARGIN.right : 0;
  const extraBottom = rotateXAxis ? ROTATED_X_AXIS_EXTRA_MARGIN.bottom : 0;
  const boundsWidth = Math.max(
    0,
    width - MARGIN.left - MARGIN.right - extraRight
  );
  const boundsHeight = Math.max(
    0,
    height - MARGIN.top - MARGIN.bottom - extraBottom
  );

  // Band labels, one per row. Memoized so the x scale only changes with data.
  const allGroups = useMemo(() => data.map((d) => String(d.x)), [data]);

  // Data wrangling: stack the data. Missing or non-numeric cells count as 0 so
  // a single bad cell cannot turn a whole column of rectangles into NaN.
  const series = useMemo(() => {
    const stackGenerator = d3
      .stack()
      .keys(allSubgroups)
      .value((d, key) => Number(d[key]) || 0)
      .order(d3.stackOrderNone);
    return stackGenerator(data);
  }, [data, allSubgroups]);

  // Y axis: upper bound = tallest stack, rounded up to the next multiple of 10.
  // Folding from 0 keeps the result at 0 (not -Infinity) for empty data.
  const max = useMemo(() => {
    const tallestStack = data.reduce((tallest, d) => {
      const stackTotal = allSubgroups.reduce(
        (sum, subgroup) => sum + (Number(d[subgroup]) || 0),
        0
      );
      return Math.max(tallest, stackTotal);
    }, 0);
    return Math.ceil(tallestStack / 10) * 10;
  }, [data, allSubgroups]);

  // Depends on the derived values actually used, so the scale follows
  // rotateXAxis (via boundsHeight) and allSubgroups (via max) as well.
  const yScale = useMemo(() => {
    return d3.scaleLinear().domain([0, max]).range([boundsHeight, 0]);
  }, [max, boundsHeight]);

  // X axis
  const xScale = useMemo(() => {
    return d3
      .scaleBand<string>()
      .domain(allGroups)
      .range([0, boundsWidth])
      .padding(0.05);
  }, [allGroups, boundsWidth]);

  // Color scale: one color per subgroup; the palette cycles past five keys.
  const colorScale = d3
    .scaleOrdinal<string>()
    .domain(allSubgroups)
    .range(["#e0ac2b", "#e85252", "#6689c6", "#9a6fb0", "#a53253"]);

  // Render the X and Y axis using d3.js, not react
  useEffect(() => {
    if (!axesRef.current) {
      return;
    }
    const axesGroup = d3.select(axesRef.current);
    // Start from scratch on every run so axes never accumulate.
    axesGroup.selectAll("*").remove();

    const xAxisGroup = axesGroup
      .append("g")
      .attr("transform", `translate(0,${boundsHeight})`)
      .call(d3.axisBottom(xScale));
    if (rotateXAxis) {
      // Only the x-axis labels are rotated; scoping the selection to the
      // x-axis group makes that independent of the order axes are appended.
      xAxisGroup
        .selectAll("text")
        .attr("text-anchor", "start")
        .attr("transform", "rotate(25)")
        .attr("dy", "+0.9em")
        .attr("dx", "+0em");
    }

    axesGroup.append("g").call(d3.axisLeft(yScale));
  }, [xScale, yScale, boundsHeight, rotateXAxis]);

  // One <g> per subgroup (layer), one <rect> per bar inside it.
  const rectangles = series.map((subgroup) => {
    return (
      <g key={subgroup.key}>
        {subgroup.map((group, j) => {
          return (
            <rect
              key={j}
              // String(...) matches how the band domain was built and, unlike
              // .toString(), does not throw on a null/undefined label.
              x={xScale(String(group.data.x))}
              y={yScale(group[1])}
              height={yScale(group[0]) - yScale(group[1])}
              width={xScale.bandwidth()}
              fill={colorScale(subgroup.key)}
              opacity={0.9}
            ></rect>
          );
        })}
      </g>
    );
  });

  const boundsTransform = `translate(${[MARGIN.left, MARGIN.top].join(",")})`;

  return (
    <div>
      <svg width={width} height={height}>
        <g width={boundsWidth} height={boundsHeight} transform={boundsTransform}>
          {rectangles}
        </g>
        <g
          width={boundsWidth}
          height={boundsHeight}
          ref={axesRef}
          transform={boundsTransform}
        />
      </svg>
    </div>
  );
};
