diff --git a/packages/react-use/src/use-raf-state/index.test.ts b/packages/react-use/src/use-raf-state/index.test.ts index 9383ea0..7914b79 100644 --- a/packages/react-use/src/use-raf-state/index.test.ts +++ b/packages/react-use/src/use-raf-state/index.test.ts @@ -32,35 +32,40 @@ describe('useRafState', () => { expect(result.current[0]).toBe(1) }) - it('should handle multiple updates correctly', () => { - const { result } = renderHook(() => useRafState(initialState)) - - act(() => { - result.current[1](2) - result.current[1](3) - }) - act(() => { - vi.advanceTimersToNextFrame() - }) - - expect(result.current[0]).toBe(3) - }) - it('should work with undefined initial state', () => { const { result } = renderHook(() => useRafState()) expect(result.current[0]).toBeUndefined() }) - it('should update state with a function', () => { + it.each([ + { + name: 'value updates', + updates: [2, 3], + expected: 3, + }, + { + name: 'function updates', + updates: [(prev: number) => prev + 2, (prev: number) => prev + 1], + expected: 3, + }, + { + name: 'mixed updates', + updates: [2, (prev: number) => prev + 1], + expected: 3, + }, + ])('should handle multiple updates correctly > $name', ({ updates, expected }) => { const { result } = renderHook(() => useRafState(initialState)) act(() => { - result.current[1]((prev) => prev + 1) + const [_, setState] = result.current + for (const update of updates) { + setState(update) + } }) act(() => { vi.advanceTimersToNextFrame() }) - expect(result.current[0]).toBe(1) + expect(result.current[0]).toBe(expected) }) }) diff --git a/packages/react-use/src/use-raf-state/index.ts b/packages/react-use/src/use-raf-state/index.ts index d7de7f3..b6854a4 100644 --- a/packages/react-use/src/use-raf-state/index.ts +++ b/packages/react-use/src/use-raf-state/index.ts @@ -1,5 +1,7 @@ +import { useCallback, useRef } from 'react' import { useRafFn } from '../use-raf-fn' import { useSafeState } from '../use-safe-state' +import { isFunction } from '../utils/basic' import type { ReactSetState, UseSafeStateOptions } from '../use-safe-state' import type { Gettable } from '../utils/basic' @@ -15,5 +17,17 @@ export function useRafState(initialState: Gettable, options?: UseRafStateO export function useRafState(): readonly [T | undefined, ReactSetState] export function useRafState(initialState?: Gettable, options?: UseRafStateOptions) { const [state, setState] = useSafeState(initialState, options) - return [state, useRafFn(setState)] as const + const stateRef = useRef(state) + + const scheduleSet = useRafFn(setState, true) + + const set = useCallback( + (value) => { + stateRef.current = isFunction(value) ? value(stateRef.current) : value + scheduleSet(() => stateRef.current) + }, + [scheduleSet], + ) + + return [state, set] as const }