import { Box, Skeleton } from "@mui/material";
import {
  gridColumnPositionsSelector,
  gridColumnsTotalWidthSelector,
  useGridApiContext,
} from "@mui/x-data-grid";
import { GridStateColDef } from "@mui/x-data-grid/internals";
import { useMemo } from "react";

import { SkeletonCell } from "./SkeletonDataGrid.style";

/**
 * Selects the visible columns that the skeleton overlay should draw.
 *
 * Reads the grid's total columns width and the list of column positions,
 * counts how many positions are less than or equal to that total width, and
 * returns that many columns taken from the start of the visible-columns list.
 *
 * NOTE(review): despite the "Count" in the name, this returns column
 * definitions, not a number.
 * NOTE(review): if the positions are column start offsets, every entry is
 * presumably <= the total width, so this would keep all visible columns —
 * confirm against the data grid selector documentation.
 *
 * @returns The leading slice of `apiRef.current.getVisibleColumns()`.
 */
const useGetColumnsCount = () => {
  const apiRef = useGridApiContext();
  const columnsTotalWidth = gridColumnsTotalWidthSelector(apiRef);
  const columnPositions = gridColumnPositionsSelector(apiRef);

  // Tally the positions that do not exceed the total columns width.
  let fittingColumns = 0;
  for (const position of columnPositions) {
    if (position <= columnsTotalWidth) {
      fittingColumns += 1;
    }
  }

  const visibleColumns = apiRef.current.getVisibleColumns();
  return visibleColumns.slice(0, fittingColumns);
};

/**
 * Builds the flat list of skeleton cells rendered inside the overlay grid.
 *
 * For each of the `skeletonRowsCount` rows, one `SkeletonCell` wrapping a
 * `Skeleton` is emitted per column, followed by one empty `SkeletonCell`
 * keyed by the row index. Cells are returned as a single flat array in
 * row-major order.
 *
 * The list is memoized on `columns` and `skeletonRowsCount`, so it is only
 * reused between renders when the caller passes a referentially stable
 * `columns` array.
 *
 * @param columns Column definitions; each column's `field` is used in the cell key.
 * @param skeletonRowsCount Number of skeleton rows to generate.
 * @returns The array of cell elements.
 */
const useChildrenElement = (
  columns: GridStateColDef[],
  skeletonRowsCount: number
) =>
  useMemo<React.ReactNode[]>(() => {
    const cells: React.ReactNode[] = [];

    for (let rowIndex = 0; rowIndex < skeletonRowsCount; rowIndex += 1) {
      // One placeholder cell per column in this row.
      for (const column of columns) {
        cells.push(
          <SkeletonCell
            key={`skeleton-${rowIndex}-${column.field}`}
            sx={{ justifyContent: "center" }}
          >
            <Skeleton sx={{ mx: 1 }} width="80%" />
          </SkeletonCell>
        );
      }

      // Extra empty cell that closes the row.
      cells.push(<SkeletonCell key={`skeleton-row-${rowIndex}`} />);
    }

    return cells;
  }, [columns, skeletonRowsCount]);

/**
 * Overlay component that renders a CSS grid of skeleton placeholder cells.
 *
 * The grid gets one fixed-width track per column returned by
 * `useGetColumnsCount` (using each column's `computedWidth` in pixels) plus a
 * trailing `1fr` track. Row height comes from the grid API's
 * `unstable_getRowHeight`, called with the id of the row at index 0. Two
 * skeleton rows are rendered.
 */
export function SkeletonLoadingOverlay() {
  const rowsToRender = 2;
  const apiRef = useGridApiContext();

  // Row height is looked up through the id of the row at index 0.
  const firstRowId = apiRef.current.getRowIdFromRowIndex(0);
  const rowHeight = apiRef.current.unstable_getRowHeight(firstRowId);

  const columns = useGetColumnsCount();

  // Fixed pixel tracks for the columns, then a flexible track for the rest.
  const columnTracks = columns.map(
    ({ computedWidth }) => `${computedWidth}px`
  );
  const gridTemplateColumns = `${columnTracks.join(" ")} 1fr`;

  const skeletonCells = useChildrenElement(columns, rowsToRender);

  return (
    <Box
      style={{
        display: "grid",
        gridTemplateColumns,
        gridAutoRows: rowHeight,
      }}
    >
      {skeletonCells}
    </Box>
  );
}
