import * as React from 'react'
import {
    flexRender,
    Row,
    type Table as TanstackTable,
} from '@tanstack/react-table'

import { cn } from '@2/lib/utils'
import {
    Table,
    TableBody,
    TableCell,
    TableHead,
    TableHeader,
    TableRow,
} from '@2/components/ui/table'
import LoadingSpinner from '@/Components/LoadingSpinner'

interface DataTableProps<TData> extends React.HTMLAttributes<HTMLDivElement> {
    /** TanStack table instance whose header groups and row model are rendered. */
    table: TanstackTable<TData>
    /**
     * Content rendered below the table, shown only while at least one of the
     * filtered rows (or one of their nested sub-rows) is selected.
     */
    floatingBar?: React.ReactNode
    /**
     * Whether to show the totals row (rendered as the first row of the body).
     */
    showTotals?: boolean
    /** When true, a loading spinner is rendered in place of the table. */
    loadingReportData?: boolean
    /**
     * Column ids to omit from the header, the totals row and the body rows.
     * Only read, never mutated, so read-only arrays are accepted.
     */
    hideColumns?: readonly string[]
}

/** Column ids whose cell in the totals row is left empty instead of aggregated. */
const TOTALS_SKIPPED_COLUMN_IDS: ReadonlySet<string> = new Set([
    'expand',
    'select',
    'label',
])

/**
 * Returns true when any row in `rows`, or any of their nested sub-rows, is
 * selected. Sub-rows are inspected even when their parent is not selected.
 */
function hasSelectedRow<TData>(rows: Row<TData>[]): boolean {
    return rows.some(
        (row) => row.getIsSelected() || hasSelectedRow(row.subRows ?? [])
    )
}

/** Expresses a column size as a rounded percentage of the table's total width. */
function widthPercent(size: number, totalWidth: number): string {
    return `${Math.round((size / totalWidth) * 100)}%`
}

/**
 * Renders a TanStack table with an optional totals row, per-call column
 * hiding, a loading state, and a floating bar that appears while any
 * filtered row is selected.
 *
 * `children` are rendered in a horizontally sticky container above the table.
 * When `loadingReportData` is true a `LoadingSpinner` replaces the table.
 */
export function DataTable<TData>({
    table,
    floatingBar = null,
    children,
    className,
    showTotals = false,
    loadingReportData = false,
    hideColumns = [],
    ...props
}: DataTableProps<TData>) {
    // NOTE(review): `className` and the remaining div attributes in `props`
    // are accepted by the prop type but are not applied to any element —
    // confirm where (if anywhere) they should be forwarded.

    const isShown = (columnId: string) => !hideColumns.includes(columnId)

    // Leaf columns that actually produce a header/body cell. Using visible
    // leaf columns keeps grouped column definitions correct; for a flat
    // column list this is the same set as the visible top-level columns.
    const shownLeafColumns = table
        .getVisibleLeafColumns()
        .filter((column) => isShown(column.id))
    const totalWidth = shownLeafColumns.reduce(
        (sum, column) => sum + column.getSize(),
        0
    )

    const filteredRows = table.getFilteredRowModel().rows
    // Row selection lives in table state rather than on the row objects, so
    // it must be a dependency for the memo to refresh when selection changes.
    const rowSelection = table.getState().rowSelection
    const showFloatingBar = React.useMemo(
        () => hasSelectedRow(filteredRows),
        [filteredRows, rowSelection]
    )

    const rows = table.getRowModel().rows

    return (
        <>
            <div className="sticky left-0">{children}</div>
            {loadingReportData ? (
                <LoadingSpinner />
            ) : (
                <Table
                    className={cn(
                        `border-b border-border min-w-full print:!w-full relative table-fixed text-xs`
                    )}
                    style={{
                        width: `${totalWidth}px`,
                    }}
                >
                    <TableHeader className="sticky top-[-1px] bg-primary">
                        {table.getHeaderGroups().map((headerGroup) => (
                            <TableRow key={headerGroup.id}>
                                {headerGroup.headers
                                    .filter((header) =>
                                        isShown(header.column.id)
                                    )
                                    .map((header) => (
                                        <TableHead
                                            key={header.id}
                                            className={cn(
                                                'py-2',
                                                header.column.columnDef.meta
                                                    ?.className
                                            )}
                                            style={{
                                                width: widthPercent(
                                                    header.column.getSize(),
                                                    totalWidth
                                                ),
                                            }}
                                        >
                                            {header.isPlaceholder
                                                ? null
                                                : flexRender(
                                                      header.column.columnDef
                                                          .header,
                                                      header.getContext()
                                                  )}
                                        </TableHead>
                                    ))}
                            </TableRow>
                        ))}
                    </TableHeader>
                    <TableBody>
                        {showTotals && (
                            <TableRow
                                style={{
                                    ['--row-base' as string]: `hsl(39, 12%, 99%)`,
                                    ['--row-hover' as string]: `hsl(39, 12%, 95%)`,
                                }}
                                className="font-bold transition-colors duration-200 bg-[var(--row-base)] hover:bg-[var(--row-hover)]"
                            >
                                {rows[0]
                                    ?.getVisibleCells()
                                    .filter((cell) => isShown(cell.column.id))
                                    .map((cell) => {
                                        // column.id is always defined (it falls
                                        // back to the accessor key), unlike
                                        // columnDef.id, and matches the id
                                        // checks used for the body rows.
                                        const columnId = cell.column.id
                                        const aggregationFn =
                                            cell.column.getAggregationFn()
                                        // NOTE(review): aggregates over
                                        // getRowModel().rows, i.e. the current
                                        // page including any expanded sub-rows
                                        // — confirm this is the intended
                                        // population for the totals.
                                        const getValue = () =>
                                            aggregationFn
                                                ? aggregationFn(
                                                      columnId,
                                                      rows,
                                                      rows
                                                  )
                                                : null
                                        const renderer =
                                            TOTALS_SKIPPED_COLUMN_IDS.has(
                                                columnId
                                            )
                                                ? null
                                                : cell.column.columnDef.cell
                                        return (
                                            <TableCell
                                                key={`total-${columnId}`}
                                                className={cn(
                                                    'py-4',
                                                    cell.column.columnDef.meta
                                                        ?.className
                                                )}
                                                style={{
                                                    width: widthPercent(
                                                        cell.column.getSize(),
                                                        totalWidth
                                                    ),
                                                }}
                                            >
                                                {flexRender(renderer, {
                                                    ...cell.getContext(),
                                                    getValue,
                                                })}
                                            </TableCell>
                                        )
                                    })}
                            </TableRow>
                        )}
                        {rows.length ? (
                            rows.map((row) => {
                                // Deeper rows are indented further and drawn
                                // lighter; the weight is clamped so it stays
                                // a valid CSS font-weight at any depth.
                                const labelStyle: React.CSSProperties = {
                                    justifyContent: 'flex-start',
                                    paddingLeft: `${1 + row.depth}rem`,
                                    fontWeight: Math.max(
                                        100,
                                        800 - row.depth * 100
                                    ),
                                }
                                return (
                                    <TableRow
                                        key={row.id}
                                        style={
                                            {
                                                '--row-base': `hsl(39, 12%, ${98 - row.depth * 2}%)`,
                                                '--row-hover': `hsl(39, 12%, ${94 - row.depth * 2}%)`,
                                                '--row-selected': `hsl(39, 12%, ${90 - row.depth * 2}%)`,
                                            } as React.CSSProperties
                                        }
                                        className="transition-colors duration-200 bg-[var(--row-base)] hover:bg-[var(--row-hover)] data-[state=selected]:bg-[var(--row-selected)] data-[state=selected]:hover:bg-[var(--row-hover)]"
                                        // Omit the attribute entirely rather
                                        // than rendering data-state="false".
                                        data-state={
                                            row.getIsSelected()
                                                ? 'selected'
                                                : undefined
                                        }
                                    >
                                        {row
                                            .getVisibleCells()
                                            .filter((cell) =>
                                                isShown(cell.column.id)
                                            )
                                            .map((cell) => (
                                                <TableCell
                                                    key={cell.id}
                                                    className={cn(
                                                        {
                                                            'pr-0':
                                                                cell.column
                                                                    .id ===
                                                                'expand',
                                                        },
                                                        cell.column.columnDef
                                                            .meta?.className
                                                    )}
                                                    style={{
                                                        width: widthPercent(
                                                            cell.column.getSize(),
                                                            totalWidth
                                                        ),
                                                        ...(cell.column.id ===
                                                        'label'
                                                            ? labelStyle
                                                            : {}),
                                                    }}
                                                >
                                                    {flexRender(
                                                        cell.column.columnDef
                                                            .cell,
                                                        cell.getContext()
                                                    )}
                                                </TableCell>
                                            ))}
                                    </TableRow>
                                )
                            })
                        ) : (
                            <TableRow>
                                <TableCell
                                    // Span exactly the columns that are
                                    // rendered (visible and not in
                                    // hideColumns); never less than 1.
                                    colSpan={Math.max(
                                        1,
                                        shownLeafColumns.length
                                    )}
                                    className="h-24 text-center"
                                >
                                    No results.
                                </TableCell>
                            </TableRow>
                        )}
                    </TableBody>
                </Table>
            )}

            <div className="flex flex-col gap-2.5 relative">
                {showFloatingBar && floatingBar}
            </div>
        </>
    )
}
