import {useEffect, useMemo, useRef, useState} from "react";
import * as d3 from "d3";
import {Legend} from "../../../shared/utils/Legend";
import {getPlot4Data} from "../../../slices/allServiceSlice";
import {useDispatch, useSelector} from "react-redux";
import {useConfig} from "../../../shared/config/ConfigContext";

const Plot4 = ({width, height, projectId, pipelineId}) => {
    const dispatch = useDispatch();
    const { apiBaseUrl } = useConfig();
    const allService = useSelector(state => state.allService);
    const svgRef = useRef();
    const correctionScoreRef = useRef();
    const [timeData, setTimeData] = useState([]);
    const [timepointsList, setTimepointsList] = useState([]);

    const processData = (input) => {
        const lines = input.trim().split('\n');

        const [...rows] = lines;

        let timeData = [];

        // Process each row of data
        rows.forEach(row => {
            const [rowTime, colTime, val] = row.split(',');

            timeData.push({
                row: rowTime,
                col: colTime,
                val: parseFloat(val)
            });
        });

        return timeData;
    }

    useEffect(() => {
        dispatch(getPlot4Data({apiBaseUrl, projectId, pipelineId}))
    }, [dispatch, apiBaseUrl])
    useEffect(() => {
        setTimeData(processData(allService.plots.plot4.content))
    }, [allService.plots.plot4.content])
    useEffect(() => {
        setTimepointsList([
            ...new Set(
                d3.map(timeData, function (d) {
                    return d.row;
                })
            )
        ]);
    }, [timeData])
    useEffect(() => {
        // Define margins
        const margin = { top: 0, right: 200, bottom: 200, left: 0 };

        // Create an SVG container
        const svg = d3
            .select(svgRef.current)
            .attr('width', width + margin.left + margin.right)
            .attr('height', height + margin.top + margin.bottom)
            .append('g')
            .attr('transform', 'translate(' + margin.left + ',' + margin.top + ')');

        // Define the scales
        const x = d3
            .scaleBand()
            .range([0, width])
            .domain(timepointsList)
            .padding(0.05);

        const y = d3
            .scaleBand()
            .range([height, 0])
            .domain(timepointsList)
            .padding(0.05);

        // Define the color scale
        const myColor = d3
            .scaleSequential()
            .interpolator(d3.interpolateViridis)
            .domain([-1, 1]);

        // Create a tooltip
        const tooltip = d3
            .select('body')
            .append('div')
            .style('opacity', 0)
            .attr('class', 'tooltip')
            .style('background-color', 'black')
            .style('border', 'solid')
            .style('border-width', '2px')
            .style('border-radius', '5px')
            .style('padding', '5px');

        // Mouseover, mousemove, mouseleave functions
        const mouseover = function () {
            tooltip.style('opacity', 1);
            d3.select(this).style('stroke', 'black').style('opacity', 1);
        };

        const mousemove = function (event, d) {
            tooltip
                .html(d.val)
                .style('top', event.clientY - 10 + 'px')
                .style('left', event.clientX + 10 + 'px');
        };

        const mouseleave = function () {
            tooltip.style('opacity', 0);
            d3.select(this).style('stroke', 'none').style('opacity', 0.8);
        };

        // Append the squares for the heatmap
        svg
            .selectAll()
            .data(timeData, (d) => d.row + ':' + d.col)
            .enter()
            .append('rect')
            .attr('x', (d) => x(d.row))
            .attr('y', (d) => y(d.col))
            .attr('rx', 4)
            .attr('ry', 4)
            .attr('width', x.bandwidth())
            .attr('height', y.bandwidth())
            .style('fill', (d) => myColor(d.val))
            .style('stroke-width', 4)
            .style('stroke', 'none')
            .style('opacity', 0.8)
            .on('mouseover', mouseover)
            .on('mousemove', mousemove)
            .on('mouseleave', mouseleave)
            .append("title").text(
            (d) =>
                "Row :" +
                d.col +
                "\n" +
                "Col: " +
                d.row +
                "\n\n" +
                "Correlation : " +
                Math.round(d.val * 100000) / 100000
        );


        return () => {
            // Cleanup: remove tooltip on unmount
            tooltip.remove();
        };
    }, [timepointsList, timeData, width, height]);
    useEffect(() => {
        if (correctionScoreRef.current) {
            const legendNode = Legend(d3.scaleSequential([-1, 1], d3.interpolateViridis), {
                title: "Correlation Score"
            })
            correctionScoreRef.current.appendChild(legendNode);
        }
        return () => {
            if (correctionScoreRef.current && correctionScoreRef.current.lastChild) {
                correctionScoreRef.current.removeChild(correctionScoreRef.current.lastChild);
            }
        };
    }, [timepointsList, timeData, width, height]);
    return (
        <div>
            <div ref={correctionScoreRef} className={"pl-4 py-4"} />
            <svg ref={svgRef} width={width + 50} height={height + 80} />
        </div>
    );
}

export default Plot4;