diff --git a/README.md b/README.md index 054480c..f6b3b6a 100644 --- a/README.md +++ b/README.md @@ -561,6 +561,36 @@ const TableSettingsDemo = () => { Learn more about the table and the column resizing properties in the react-table [docs](https://tanstack.com/table/v8/docs/api/features/column-sizing) +## Memoization (experimental) + +Pass `experimentalMemoization` on `Table` / `BaseTable` to enable +`React.memo` on rows and cells. This avoids re-rendering every row and cell +when one row's state changes. The flag is opt-in; without it the rendering +behavior is unchanged. + +```tsx + +``` + +By default, the memo comparator tracks `row.getIsSelected()` and +`row.getIsExpanded()`. If your custom cells read other row state (or external +state keyed by row id), declare it via `getRowVersion`: + +```tsx +const getRowVersion = (row: Row) => + [row.getIsSelected(), row.getIsExpanded(), row.getIsPinned()] as const; + +
; +``` + +`getRowVersion` is called once per row per parent render. Returned values are +compared element-wise with `Object.is` — any change invalidates the row's memo +and re-renders only that row's cells. + +See [`docs/MIGRATION-experimentalMemoization.md`](docs/MIGRATION-experimentalMemoization.md) +for anti-patterns that defeat memoization, the verification recipe, and a +worked migration example. + ## Known Issues and Compatibility ### React 19 + React Compiler Compatibility diff --git a/src/components/BaseCell/BaseCell.memo.tsx b/src/components/BaseCell/BaseCell.memo.tsx new file mode 100644 index 0000000..2757ca7 --- /dev/null +++ b/src/components/BaseCell/BaseCell.memo.tsx @@ -0,0 +1,21 @@ +import * as React from 'react'; + +import {arraysShallowEqual} from '../../utils'; + +import type {BaseCellProps} from './BaseCell'; +import {BaseCell} from './BaseCell'; + +function areCellPropsEqual(prev: BaseCellProps, next: BaseCellProps): boolean { + return ( + arraysShallowEqual(prev._rowVersion ?? [], next._rowVersion ?? []) && + prev.cell === next.cell && + prev.className === next.className && + prev.attributes === next.attributes && + prev.style === next.style && + prev.children === next.children && + prev.colSpan === next.colSpan && + prev['aria-colindex'] === next['aria-colindex'] + ); +} + +export const MemoBaseCell = React.memo(BaseCell, areCellPropsEqual) as typeof BaseCell; diff --git a/src/components/BaseCell/BaseCell.tsx b/src/components/BaseCell/BaseCell.tsx index 55a2aad..af34265 100644 --- a/src/components/BaseCell/BaseCell.tsx +++ b/src/components/BaseCell/BaseCell.tsx @@ -13,6 +13,7 @@ export interface BaseCellProps attributes?: | React.TdHTMLAttributes | ((cell?: Cell) => React.TdHTMLAttributes); + _rowVersion?: readonly unknown[]; } export const BaseCell = ({ @@ -21,6 +22,7 @@ export const BaseCell = ({ className: classNameProp, style, attributes: attributesProp, + _rowVersion: _rowVersionDiscarded, ...restProps }: BaseCellProps) => { const attributes = typeof attributesProp === 'function' ? attributesProp(cell) : attributesProp; diff --git a/src/components/BaseCell/index.ts b/src/components/BaseCell/index.ts index 413c61f..2c9fa97 100644 --- a/src/components/BaseCell/index.ts +++ b/src/components/BaseCell/index.ts @@ -1 +1,2 @@ export * from './BaseCell'; +export * from './BaseCell.memo'; diff --git a/src/components/BaseDraggableRow/BaseDraggableRow.memo.tsx b/src/components/BaseDraggableRow/BaseDraggableRow.memo.tsx new file mode 100644 index 0000000..151422b --- /dev/null +++ b/src/components/BaseDraggableRow/BaseDraggableRow.memo.tsx @@ -0,0 +1,62 @@ +import * as React from 'react'; + +import {arraysShallowEqual} from '../../utils'; +import {MemoBaseCell} from '../BaseCell/BaseCell.memo'; + +import {BaseDraggableRow} from './BaseDraggableRow'; +import type {BaseDraggableRowProps} from './BaseDraggableRow'; + +export interface MemoBaseDraggableRowProps< + TData, + TScrollElement extends Element | Window = HTMLDivElement, +> extends BaseDraggableRowProps { + _rowVersion: readonly unknown[]; +} + +// eslint-disable-next-line complexity +function areEqual( + prev: Readonly>, + next: Readonly>, +): boolean { + return ( + prev.row === next.row && + arraysShallowEqual(prev._rowVersion ?? [], next._rowVersion ?? []) && + prev.table === next.table && + prev.virtualItem?.start === next.virtualItem?.start && + prev.virtualItem?.size === next.virtualItem?.size && + prev.style === next.style && + prev.cellClassName === next.cellClassName && + prev.className === next.className && + prev.onClick === next.onClick && + prev.getIsCustomRow === next.getIsCustomRow && + prev.getIsGroupHeaderRow === next.getIsGroupHeaderRow && + prev.renderCustomRowContent === next.renderCustomRowContent && + prev.renderGroupHeader === next.renderGroupHeader && + prev.renderGroupHeaderRowContent === next.renderGroupHeaderRowContent && + prev.getGroupTitle === next.getGroupTitle && + prev.groupHeaderClassName === next.groupHeaderClassName && + prev.attributes === next.attributes && + prev.cellAttributes === next.cellAttributes && + prev.rowVirtualizer === next.rowVirtualizer && + prev['aria-rowindex'] === next['aria-rowindex'] && + prev['aria-selected'] === next['aria-selected'] + ); +} + +const BaseDraggableRowWithMemoCell = React.forwardRef(function BaseDraggableRowWithMemoCellRender< + TData, + TScrollElement extends Element | Window, +>(props: BaseDraggableRowProps, ref: React.Ref) { + return ; +}) as ( + props: BaseDraggableRowProps & {ref?: React.Ref}, +) => React.ReactElement; + +export const MemoBaseDraggableRow = React.memo(BaseDraggableRowWithMemoCell, areEqual) as < + TData, + TScrollElement extends Element | Window = HTMLDivElement, +>( + props: MemoBaseDraggableRowProps & { + ref?: React.Ref; + }, +) => React.ReactElement; diff --git a/src/components/BaseDraggableRow/index.ts b/src/components/BaseDraggableRow/index.ts index dcd14ef..e24caa5 100644 --- a/src/components/BaseDraggableRow/index.ts +++ b/src/components/BaseDraggableRow/index.ts @@ -1 +1,2 @@ export * from './BaseDraggableRow'; +export * from './BaseDraggableRow.memo'; diff --git a/src/components/BaseGroupHeader/BaseGroupHeader.tsx b/src/components/BaseGroupHeader/BaseGroupHeader.tsx index 0088617..41ea0ca 100644 --- a/src/components/BaseGroupHeader/BaseGroupHeader.tsx +++ b/src/components/BaseGroupHeader/BaseGroupHeader.tsx @@ -1,4 +1,4 @@ -import type * as React from 'react'; +import * as React from 'react'; import type {Row} from '@tanstack/react-table'; @@ -17,11 +17,13 @@ export const BaseGroupHeader = ({ className, getGroupTitle, }: BaseGroupHeaderProps) => { + const isExpanded = row.getIsExpanded(); + return (

; + }; + + const {rerender} = render(); + + // Baseline: every cell rendered exactly once. + expect(cellRenderCount.get('a')).toBe(1); + expect(cellRenderCount.get('b')).toBe(1); + expect(cellRenderCount.get('c')).toBe(1); + + // Force a row-level re-render by changing the className. cell objects from + // tanstack are stable across this re-render, so MemoBaseCell should skip. + rerender(); + + expect(cellRenderCount.get('a')).toBe(1); + expect(cellRenderCount.get('b')).toBe(1); + expect(cellRenderCount.get('c')).toBe(1); + }); + + it('re-renders only the toggled row when row.getIsSelected changes (default getRowVersion)', () => { + // Capture the table instance so the test can drive state through TanStack + // without rendering UI controls. + const tableRef: {current: ReturnType> | null} = {current: null}; + + const Wrapper = () => { + const table = useTable({columns, data, getRowId, enableRowSelection: true}); + tableRef.current = table; + return
; + }; + + render(); + + expect(cellRenderCount.get('a')).toBe(1); + expect(cellRenderCount.get('b')).toBe(1); + expect(cellRenderCount.get('c')).toBe(1); + + // Toggle selection of row b. TanStack updates internal state and triggers + // a re-render of `Wrapper` (the component using useTable). The default + // getRowVersion includes getIsSelected, so b's row+cell version changes. + act(() => { + tableRef.current!.getRow('b').toggleSelected(); + }); + + // a and c: row reference unchanged + version unchanged → memo skips → cell skipped. + // b: version changed → row re-renders → cell re-renders. + expect(cellRenderCount.get('a')).toBe(1); + expect(cellRenderCount.get('b')).toBe(2); + expect(cellRenderCount.get('c')).toBe(1); + }); + + it('re-renders only the row whose custom version slice changed', () => { + // Custom external state keyed by row id, not on the row itself. + const flagged = new Map([ + ['a', false], + ['b', false], + ['c', false], + ]); + + const getRowVersion = (row: Row) => [flagged.get(row.id) ?? false] as const; + + const Wrapper = ({forceRerender: _}: {forceRerender: number}) => { + const table = useTable({columns, data, getRowId}); + return
; + }; + + const {rerender} = render(); + + expect(cellRenderCount.get('a')).toBe(1); + expect(cellRenderCount.get('b')).toBe(1); + expect(cellRenderCount.get('c')).toBe(1); + + // Mutate the external map for row b only, then force a parent re-render + // by changing the unrelated `forceRerender` prop. + flagged.set('b', true); + rerender(); + + // a and c had unchanged versions ([false]) — skipped. + // b's version changed ([false] -> [true]) — re-rendered. + expect(cellRenderCount.get('a')).toBe(1); + expect(cellRenderCount.get('b')).toBe(2); + expect(cellRenderCount.get('c')).toBe(1); + }); + + it('detects length changes in getRowVersion as a state change', () => { + let extended = false; + const getRowVersion = (row: Row) => + extended ? ([row.id, 'extra'] as readonly unknown[]) : ([row.id] as readonly unknown[]); + + const Wrapper = ({forceRerender: _}: {forceRerender: number}) => { + const table = useTable({columns, data, getRowId}); + return
; + }; + + const {rerender} = render(); + + expect(cellRenderCount.get('a')).toBe(1); + expect(cellRenderCount.get('b')).toBe(1); + expect(cellRenderCount.get('c')).toBe(1); + + // Switch to the longer-array variant. arraysShallowEqual sees a length + // mismatch and treats every row as changed. + extended = true; + rerender(); + + expect(cellRenderCount.get('a')).toBe(2); + expect(cellRenderCount.get('b')).toBe(2); + expect(cellRenderCount.get('c')).toBe(2); + }); +}); diff --git a/src/components/BaseRow/index.ts b/src/components/BaseRow/index.ts index ea6a4a0..c564e9c 100644 --- a/src/components/BaseRow/index.ts +++ b/src/components/BaseRow/index.ts @@ -1 +1,2 @@ export * from './BaseRow'; +export * from './BaseRow.memo'; diff --git a/src/components/BaseTable/BaseTable.tsx b/src/components/BaseTable/BaseTable.tsx index 4793e71..6bcf509 100644 --- a/src/components/BaseTable/BaseTable.tsx +++ b/src/components/BaseTable/BaseTable.tsx @@ -6,12 +6,15 @@ import type {VirtualItem, Virtualizer} from '@tanstack/react-virtual'; import type {HeaderGroup} from '../../types/base'; import {getAriaMultiselectable, getAriaRowIndexMap, shouldRenderFooterRow} from '../../utils'; import {BaseDraggableRow} from '../BaseDraggableRow'; +import {MemoBaseDraggableRow} from '../BaseDraggableRow/BaseDraggableRow.memo'; import type {BaseFooterRowProps} from '../BaseFooterRow'; import {BaseFooterRow} from '../BaseFooterRow'; import type {BaseHeaderRowProps} from '../BaseHeaderRow'; import {BaseHeaderRow} from '../BaseHeaderRow'; import type {BaseRowProps} from '../BaseRow'; import {BaseRow} from '../BaseRow'; +import type {MemoBaseRowProps} from '../BaseRow/BaseRow.memo'; +import {MemoBaseRow} from '../BaseRow/BaseRow.memo'; import {LastSelectedRowContextProvider} from '../LastSelectedRowContext'; import {SortableListContext} from '../SortableListContext'; @@ -19,6 +22,48 @@ import {b} from './BaseTable.classname'; import './BaseTable.scss'; +function resolveBodyRow( + virtualItemOrRow: VirtualItem | Row, + rows: Row[], + rowVirtualizer: Virtualizer | undefined, + index: number, +) { + const isVirtual = Boolean(rowVirtualizer); + + const row = isVirtual + ? rows[(virtualItemOrRow as VirtualItem).index] + : (virtualItemOrRow as Row); + const rowIndex = isVirtual ? (virtualItemOrRow as VirtualItem).index : index; + + const virtualItem = isVirtual ? (virtualItemOrRow as VirtualItem) : undefined; + + return {row, rowIndex, virtualItem, key: virtualItem?.key ?? row.id}; +} + +function getTreeStyle( + row: Row, + nextRow: Row | undefined, + cache: Map, +): React.CSSProperties | undefined { + if (row.depth === 0) return undefined; + + const lastNested = nextRow?.depth === 0 ? 1 : 0; + const cacheKey = `${row.depth}-${lastNested}`; + + if (!cache.has(cacheKey)) { + cache.set(cacheKey, { + '--_--tree-depth': row.depth, + '--_--last-nested': lastNested, + } as React.CSSProperties); + } + + return cache.get(cacheKey); +} + +function defaultGetRowVersion(row: Row): readonly unknown[] { + return [row.getIsSelected(), row.getIsExpanded()]; +} + export interface BaseTableProps { /** The table instance returned from the `useTable` hook */ table: Table; @@ -111,6 +156,17 @@ export interface BaseTableProps` element should be rendered */ withHeader?: boolean; + /** EXPERIMENTAL. Enables React.memo on rows and cells to avoid full-table re-renders */ + experimentalMemoization?: boolean; + /** + * EXPERIMENTAL. Snapshot of row state used by the memo comparator. + * Only relevant when `experimentalMemoization` is true. Returned values + * are compared element-wise via Object.is — anything your custom cells + * read from the row (or external state keyed by row id) should appear here. + * + * Default: (row) => [row.getIsSelected(), row.getIsExpanded()] + */ + getRowVersion?: (row: Row) => readonly unknown[]; } export const BaseTable = React.forwardRef( @@ -159,10 +215,14 @@ export const BaseTable = React.forwardRef( stickyHeader = false, withFooter = false, withHeader = true, + experimentalMemoization = false, + getRowVersion, }: BaseTableProps, ref: React.Ref, ) => { const draggableContext = React.useContext(SortableListContext); + const memoStyleCache = React.useRef>(new Map()); + const draggingRowIndex = draggableContext?.activeItemIndex ?? -1; const {rows, rowsById} = table.getRowModel(); @@ -212,26 +272,21 @@ export const BaseTable = React.forwardRef( ); }; + const resolveRowVersion: (row: Row) => readonly unknown[] = + getRowVersion ?? defaultGetRowVersion; + const renderBodyRows = () => { return bodyRows.map((virtualItemOrRow, index) => { - const row = rowVirtualizer - ? rows[virtualItemOrRow.index] - : (virtualItemOrRow as Row); - - const rowIndex = rowVirtualizer ? virtualItemOrRow.index : index; - - const style = - row.depth > 0 - ? { - '--_--tree-depth': row.depth, - '--_--last-nested': rows[rowIndex + 1]?.depth === 0 ? 1 : 0, - } - : undefined; + const {row, rowIndex, virtualItem, key} = resolveBodyRow( + virtualItemOrRow, + rows, + rowVirtualizer, + index, + ); - const virtualItem = rowVirtualizer ? (virtualItemOrRow as VirtualItem) : undefined; - const key = virtualItem?.key ?? row.id; + const isSelected = table.options.enableRowSelection ? row.getIsSelected() : false; - const rowProps: BaseRowProps = { + const baseProps: BaseRowProps = { cellClassName, className: rowClassName, getGroupTitle, @@ -248,18 +303,34 @@ export const BaseTable = React.forwardRef( rowVirtualizer, table, virtualItem, - style, + style: + row.depth > 0 + ? { + '--_--tree-depth': row.depth, + '--_--last-nested': rows[rowIndex + 1]?.depth === 0 ? 1 : 0, + } + : undefined, 'aria-rowindex': headerRowCount + ariaRowIndexMap[row.id], - 'aria-selected': table.options.enableRowSelection - ? row.getIsSelected() - : undefined, + 'aria-selected': table.options.enableRowSelection ? isSelected : undefined, }; - if (draggableContext) { - return ; + if (!experimentalMemoization) { + if (draggableContext) { + return ; + } + return ; } - return ; + const memoizedProps: MemoBaseRowProps = { + ...baseProps, + style: getTreeStyle(row, rows[rowIndex + 1], memoStyleCache.current), + _rowVersion: resolveRowVersion(row), + }; + + if (draggableContext) { + return ; + } + return ; }); }; diff --git a/src/components/Table/__stories__/Table.stories.tsx b/src/components/Table/__stories__/Table.stories.tsx index eae71df..485f9d7 100644 --- a/src/components/Table/__stories__/Table.stories.tsx +++ b/src/components/Table/__stories__/Table.stories.tsx @@ -6,6 +6,7 @@ import {DefaultStory} from './stories/DefaultStory'; import {FilteringStory} from './stories/FilteringStory'; import {GroupingStory} from './stories/GroupingStory'; import {GroupingWithSelectionStory} from './stories/GroupingWithSelectionStory'; +import {RenderCountTreeStory} from './stories/RenderCountTreeStory'; import {ReorderingStory} from './stories/ReorderingStory'; import {ReorderingWithVirtualizationStory} from './stories/ReorderingWithVirtualizationStory'; import {RowLinkStory} from './stories/RowLinkStory'; @@ -108,3 +109,8 @@ export const Grouping: StoryObj = { export const GroupingWithSelection: StoryObj = { render: GroupingWithSelectionStory, }; + +export const RenderCountTree: StoryObj = { + render: RenderCountTreeStory, + name: 'Experimental: Render Count (memoization demo)', +}; diff --git a/src/components/Table/__stories__/stories/RenderCountTreeStory.tsx b/src/components/Table/__stories__/stories/RenderCountTreeStory.tsx new file mode 100644 index 0000000..7b3518d --- /dev/null +++ b/src/components/Table/__stories__/stories/RenderCountTreeStory.tsx @@ -0,0 +1,135 @@ +import * as React from 'react'; + +import type {ColumnDef, ExpandedState} from '@tanstack/react-table'; + +import {useTable} from '../../../../hooks'; +import {Table} from '../../Table'; +import type {TableProps} from '../../Table'; + +interface Item { + id: string; + name: string; + renderCount?: number; + children?: Item[]; +} + +// Build a large tree: 50 root rows each with 20 children = 1050 rows total +function buildData(): Item[] { + const items: Item[] = []; + for (let i = 0; i < 50; i++) { + const children: Item[] = []; + for (let j = 0; j < 20; j++) { + children.push({id: `${i}-${j}`, name: `Row ${i}.${j}`}); + } + items.push({id: `${i}`, name: `Group ${i}`, children}); + } + return items; +} + +const data = buildData(); + +const NameCell = ({ + row, + value, +}: { + row: import('@tanstack/react-table').Row; + value: string; +}) => { + const isExpanded = row.getIsExpanded(); + + const renderCountRef = React.useRef(0); + renderCountRef.current += 1; + const renderCount = renderCountRef.current; + + return ( +
+ {row.getCanExpand() && ( + + )} + {value} + 1 ? 'red' : 'green', + fontFamily: 'monospace', + }} + > + renders: {renderCount} + +
+ ); +}; + +const getColumns = (memo: boolean): ColumnDef[] => [ + { + accessorKey: 'name', + header: `Name (${memo ? 'memo ON' : 'memo OFF'})`, + size: 400, + cell: (info) => ()} />, + }, + {accessorKey: 'id', header: 'ID', size: 120}, +]; + +export const RenderCountTreeStory = (props: Partial>) => { + const [experimentalMemoization, setExperimentalMemoization] = React.useState(false); + const [expanded, setExpanded] = React.useState({}); + + const columns = React.useMemo( + () => getColumns(experimentalMemoization), + [experimentalMemoization], + ); + + const table = useTable({ + columns, + data, + getSubRows: (item) => item.children, + enableExpanding: true, + onExpandedChange: setExpanded, + state: {expanded}, + }); + + return ( +
+
+ + + Toggle a row and watch the render counters — green = rendered once, red = + re-rendered. With memo ON, only the toggled row and its children should + increment. + +
+
+
+ + + ); +}; diff --git a/src/components/TreeExpandableCell/TreeExpandableCell.tsx b/src/components/TreeExpandableCell/TreeExpandableCell.tsx index 6f446c9..b9f95b5 100644 --- a/src/components/TreeExpandableCell/TreeExpandableCell.tsx +++ b/src/components/TreeExpandableCell/TreeExpandableCell.tsx @@ -10,6 +10,8 @@ export interface TreeExpandableCellProps extends React.PropsWithChildren } export const TreeExpandableCell = ({row, children}: TreeExpandableCellProps) => { + const isExpanded = row.getIsExpanded(); + return ( {children} diff --git a/src/index.ts b/src/index.ts index ab0df41..f895808 100644 --- a/src/index.ts +++ b/src/index.ts @@ -51,6 +51,9 @@ export type { UseColumnsAutoSizeProps, } from './hooks'; +export {MemoBaseRow, MemoBaseCell, MemoBaseDraggableRow} from './components'; +export type {MemoBaseRowProps, MemoBaseDraggableRowProps} from './components'; + export {getVirtualRowRangeExtractor} from './utils'; export type { diff --git a/src/utils/__tests__/arraysShallowEqual.test.ts b/src/utils/__tests__/arraysShallowEqual.test.ts new file mode 100644 index 0000000..3063ed7 --- /dev/null +++ b/src/utils/__tests__/arraysShallowEqual.test.ts @@ -0,0 +1,39 @@ +import {arraysShallowEqual} from '../arraysShallowEqual'; + +describe('arraysShallowEqual', () => { + it('returns true for the same reference', () => { + const a = [1, 'x', null]; + expect(arraysShallowEqual(a, a)).toBe(true); + }); + + it('returns true for arrays with element-wise equal primitives', () => { + expect(arraysShallowEqual([1, 'x', true], [1, 'x', true])).toBe(true); + }); + + it('returns false when lengths differ', () => { + expect(arraysShallowEqual([1, 2], [1, 2, 3])).toBe(false); + expect(arraysShallowEqual([1, 2, 3], [1, 2])).toBe(false); + }); + + it('returns false when any element differs by Object.is', () => { + expect(arraysShallowEqual([1, 'x', true], [1, 'x', false])).toBe(false); + }); + + it('uses Object.is, not ===, for NaN', () => { + expect(arraysShallowEqual([NaN], [NaN])).toBe(true); + }); + + it('uses Object.is, not ===, for +0 / -0', () => { + expect(arraysShallowEqual([0], [-0])).toBe(false); + }); + + it('returns true for two empty arrays', () => { + expect(arraysShallowEqual([], [])).toBe(true); + }); + + it('compares object references, not deep contents', () => { + const obj = {a: 1}; + expect(arraysShallowEqual([obj], [obj])).toBe(true); + expect(arraysShallowEqual([{a: 1}], [{a: 1}])).toBe(false); + }); +}); diff --git a/src/utils/arraysShallowEqual.ts b/src/utils/arraysShallowEqual.ts new file mode 100644 index 0000000..405fa59 --- /dev/null +++ b/src/utils/arraysShallowEqual.ts @@ -0,0 +1,8 @@ +export function arraysShallowEqual(a: readonly unknown[], b: readonly unknown[]): boolean { + if (a === b) return true; + if (a.length !== b.length) return false; + for (let i = 0; i < a.length; i++) { + if (!Object.is(a[i], b[i])) return false; + } + return true; +} diff --git a/src/utils/index.ts b/src/utils/index.ts index bfd3e34..e2990a7 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -1,3 +1,4 @@ +export * from './arraysShallowEqual'; export * from './cn'; export * from './getAriaMultiselectable'; export * from './getAriaRowIndexMap';