diff --git a/src/vanilla/atom.ts b/src/vanilla/atom.ts index d0907329ea9..c070779e002 100644 --- a/src/vanilla/atom.ts +++ b/src/vanilla/atom.ts @@ -1,3 +1,5 @@ +import type { AtomState, PrdOrDevStore as Store } from './store' + type Getter = (atom: Atom) => Value type Setter = ( @@ -47,6 +49,11 @@ export interface Atom { * @private */ debugPrivate?: boolean + /** + * Fires after atom is referenced by the store for the first time + * For internal use only and subject to change without notice. + */ + INTERNAL_onInit?: (store: Store, atomState: AtomState) => void } export interface WritableAtom diff --git a/src/vanilla/store.ts b/src/vanilla/store.ts index 0a7d7e368ab..3d36a75d423 100644 --- a/src/vanilla/store.ts +++ b/src/vanilla/store.ts @@ -62,6 +62,8 @@ const isPromiseLike = ( ): x is PromiseLike & { onCancel?: (fn: CancelHandler) => void } => typeof (x as any)?.then === 'function' +type BatchListener = (batch: Batch) => void + /** * State tracked for mounted atoms. An atom is considered "mounted" if it has a * subscriber, or is a transitive dependency of another atom that has a @@ -70,26 +72,28 @@ const isPromiseLike = ( * The mounted state of an atom is freed once it is no longer mounted. */ type Mounted = { - /** Set of listeners to notify when the atom value changes. */ - readonly l: Set<() => void> + /** Count of listeners to notify when the atom value changes. */ + l: number /** Set of mounted atoms that the atom depends on. */ readonly d: Set /** Set of mounted atoms that depends on the atom. */ readonly t: Set /** Function to run when the atom is unmounted. */ - u?: (batch: Batch) => void + u?: BatchListener } /** * Mutable atom state, * tracked for both mounted and unmounted atoms in a store. */ -type AtomState = { +export type AtomState = { /** * Map of atoms that the atom depends on. * The map value is the epoch number of the dependency. */ readonly d: Map + /** Set of priority listeners to run when the atom value changes. */ + readonly l: Set /** * Set of atoms with pending promise that depend on the atom. * @@ -169,11 +173,11 @@ type Batch = Readonly<{ /** Atom dependents map */ D: Map> /** High priority functions */ - H: Set<() => void> + H: Set /** Medium priority functions */ - M: Set<() => void> + M: Set /** Low priority functions */ - L: Set<() => void> + L: Set }> const createBatch = (): Batch => ({ @@ -185,8 +189,8 @@ const createBatch = (): Batch => ({ const addBatchFunc = ( batch: Batch, + fn: BatchListener, priority: BatchPriority, - fn: () => void, ) => { batch[priority].add(fn) } @@ -198,9 +202,12 @@ const registerBatchAtom = ( ) => { if (!batch.D.has(atom)) { batch.D.set(atom, new Set()) - addBatchFunc(batch, 'M', () => { - atomState.m?.l.forEach((listener) => addBatchFunc(batch, 'M', listener)) - }) + const scheduleListeners = () => { + for (const [listener, priority] of atomState.l) { + addBatchFunc(batch, listener, priority) + } + } + addBatchFunc(batch, scheduleListeners, 'H') } } @@ -221,9 +228,9 @@ const getBatchAtomDependents = (batch: Batch, atom: AnyAtom) => const flushBatch = (batch: Batch) => { let error: AnyError let hasError = false - const call = (fn: () => void) => { + const call = (fn: BatchListener) => { try { - fn() + fn(batch) } catch (e) { if (!hasError) { error = e @@ -245,9 +252,17 @@ const flushBatch = (batch: Batch) => { } } +type AtomOnInit = ( + atom: Atom, + atomState: AtomState, +) => void + // internal & unstable type type StoreArgs = readonly [ - getAtomState: (atom: Atom) => AtomState, + getAtomState: ( + atom: Atom, + atomOnInit?: AtomOnInit | undefined, + ) => AtomState, atomRead: ( atom: Atom, ...params: Parameters['read']> @@ -260,6 +275,7 @@ type StoreArgs = readonly [ atom: WritableAtom, setAtom: (...args: Args) => Result, ) => OnUnmount | void, + createAtomOnInit: (store: Store) => AtomOnInit, ] // for debugging purpose only @@ -271,7 +287,7 @@ type DevStoreRev4 = { dev4_restore_atoms: (values: Iterable) => void } -type PrdStore = { +type Store = { get: (atom: Atom) => Value set: ( atom: WritableAtom, @@ -281,20 +297,14 @@ type PrdStore = { unstable_derive: (fn: (...args: StoreArgs) => StoreArgs) => Store } -type Store = PrdStore | (PrdStore & DevStoreRev4) - export type INTERNAL_DevStoreRev4 = DevStoreRev4 -export type INTERNAL_PrdStore = PrdStore - -const buildStore = ( - ...[getAtomState, atomRead, atomWrite, atomOnMount]: StoreArgs -): Store => { - // for debugging purpose only - let debugMountedAtoms: Set +export type INTERNAL_PrdStore = Store - if (import.meta.env?.MODE !== 'production') { - debugMountedAtoms = new Set() - } +const buildStore = (...storeArgs: StoreArgs): Store => { + const [_getAtomState, atomRead, atomWrite, atomOnMount, createAtomOnInit] = + storeArgs + const getAtomState = (atom: Atom) => + _getAtomState(atom, createAtomOnInit(store)) const setAtomStateValueOrPromise = ( atom: AnyAtom, @@ -512,7 +522,7 @@ const buildStore = ( // Step 2: use the topSortedReversed atom list to recompute all affected atoms // Track what's changed, so that we can short circuit when possible - addBatchFunc(batch, 'H', () => { + const finishRecompute = () => { const changedAtoms = new Set([atom]) for (let i = topSortedReversed.length - 1; i >= 0; --i) { const [a, aState, prevEpochNumber] = topSortedReversed[i]! @@ -533,7 +543,8 @@ const buildStore = ( } delete aState.x } - }) + } + addBatchFunc(batch, finishRecompute, 'H') } const writeAtomState = ( @@ -630,13 +641,10 @@ const buildStore = ( } // mount self atomState.m = { - l: new Set(), d: new Set(atomState.d.keys()), + l: 0, t: new Set(), } - if (import.meta.env?.MODE !== 'production') { - debugMountedAtoms.add(atom) - } if (isActuallyWritableAtom(atom)) { const mounted = atomState.m let setAtom: (...args: unknown[]) => unknown @@ -657,14 +665,15 @@ const buildStore = ( isSync = false } } - addBatchFunc(batch, 'L', () => { + const processOnMount = () => { const onUnmount = createInvocationContext(batch, () => atomOnMount(atom, (...args) => setAtom(...args)), ) if (onUnmount) { mounted.u = (batch) => createInvocationContext(batch, onUnmount) } - }) + } + addBatchFunc(batch, processOnMount, 'L') } } return atomState.m @@ -677,18 +686,15 @@ const buildStore = ( ): Mounted | undefined => { if ( atomState.m && - !atomState.m.l.size && + !atomState.m.l && !Array.from(atomState.m.t).some((a) => getAtomState(a).m?.d.has(atom)) ) { // unmount self const onUnmount = atomState.m.u if (onUnmount) { - addBatchFunc(batch, 'L', () => onUnmount(batch)) + addBatchFunc(batch, onUnmount, 'L') } delete atomState.m - if (import.meta.env?.MODE !== 'production') { - debugMountedAtoms.delete(atom) - } // unmount dependencies for (const a of atomState.d.keys()) { const aMounted = unmountAtom(batch, a, getAtomState(a)) @@ -703,19 +709,21 @@ const buildStore = ( const batch = createBatch() const atomState = getAtomState(atom) const mounted = mountAtom(batch, atom, atomState) - const listeners = mounted.l - listeners.add(listener) + const priorityListener = [() => listener(), 'M'] as const + ++mounted.l + atomState.l.add(priorityListener) flushBatch(batch) return () => { - listeners.delete(listener) const batch = createBatch() + --mounted.l + atomState.l.delete(priorityListener) unmountAtom(batch, atom, atomState) flushBatch(batch) } } - const unstable_derive = (fn: (...args: StoreArgs) => StoreArgs) => - buildStore(...fn(getAtomState, atomRead, atomWrite, atomOnMount)) + const unstable_derive: Store['unstable_derive'] = (fn) => + buildStore(...fn(...storeArgs)) const store: Store = { get: readAtom, @@ -723,66 +731,120 @@ const buildStore = ( sub: subscribeAtom, unstable_derive, } - if (import.meta.env?.MODE !== 'production') { - const devStore: DevStoreRev4 = { - // store dev methods (these are tentative and subject to change without notice) - dev4_get_internal_weak_map: () => ({ - get: (atom) => { - const atomState = getAtomState(atom) - if (atomState.n === 0) { - // for backward compatibility - return undefined + return store +} + +const deriveDevStoreRev4 = (store: Store): Store & DevStoreRev4 => { + const proxyAtomStateMap = new WeakMap() + const debugMountedAtoms = new Set() + let savedGetAtomState: StoreArgs[0] + let inRestoreAtom = 0 + const derivedStore = store.unstable_derive( + (getAtomState, atomRead, atomWrite, atomOnMount, createAtomOnInit) => { + savedGetAtomState = (a) => getAtomState(a, createAtomOnInit(derivedStore)) + return [ + (atom, atomOnInit) => { + let proxyAtomState = proxyAtomStateMap.get(atom) + if (!proxyAtomState) { + const atomState = getAtomState(atom, atomOnInit) + proxyAtomState = new Proxy(atomState, { + set(target, prop, value) { + if (prop === 'm') { + debugMountedAtoms.add(atom) + } + return Reflect.set(target, prop, value) + }, + deleteProperty(target, prop) { + if (prop === 'm') { + debugMountedAtoms.delete(atom) + } + return Reflect.deleteProperty(target, prop) + }, + }) + proxyAtomStateMap.set(atom, proxyAtomState) } - return atomState + return proxyAtomState }, - }), - dev4_get_mounted_atoms: () => debugMountedAtoms, - dev4_restore_atoms: (values) => { - const batch = createBatch() - for (const [atom, value] of values) { - if (hasInitialValue(atom)) { - const atomState = getAtomState(atom) - const prevEpochNumber = atomState.n - setAtomStateValueOrPromise(atom, atomState, value) - mountDependencies(batch, atom, atomState) - if (prevEpochNumber !== atomState.n) { - registerBatchAtom(batch, atom, atomState) - recomputeDependents(batch, atom, atomState) - } + atomRead, + (atom, getter, setter, ...args) => { + if (inRestoreAtom) { + return setter(atom, ...args) } + return atomWrite(atom, getter, setter, ...args) + }, + atomOnMount, + createAtomOnInit, + ] + }, + ) + const savedStoreSet = derivedStore.set + const devStore: DevStoreRev4 = { + // store dev methods (these are tentative and subject to change without notice) + dev4_get_internal_weak_map: () => ({ + get: (atom) => { + const atomState = savedGetAtomState(atom) + if (atomState.n === 0) { + // for backward compatibility + return undefined } - flushBatch(batch) + return atomState }, - } - Object.assign(store, devStore) + }), + dev4_get_mounted_atoms: () => debugMountedAtoms, + dev4_restore_atoms: (values) => { + const restoreAtom: WritableAtom = { + read: () => null, + write: (_get, set) => { + ++inRestoreAtom + try { + for (const [atom, value] of values) { + if (hasInitialValue(atom)) { + set(atom as never, value) + } + } + } finally { + --inRestoreAtom + } + }, + } + savedStoreSet(restoreAtom) + }, } - return store + return Object.assign(derivedStore, devStore) } -export const createStore = (): Store => { +export type PrdOrDevStore = Store | (Store & DevStoreRev4) + +export const createStore = (): PrdOrDevStore => { const atomStateMap = new WeakMap() - const getAtomState = (atom: Atom) => { + const getAtomState = (atom: Atom, atomOnInit?: AtomOnInit) => { if (import.meta.env?.MODE !== 'production' && !atom) { throw new Error('Atom is undefined or null') } let atomState = atomStateMap.get(atom) as AtomState | undefined if (!atomState) { - atomState = { d: new Map(), p: new Set(), n: 0 } + atomState = { d: new Map(), l: new Set(), p: new Set(), n: 0 } atomStateMap.set(atom, atomState) + atomOnInit?.(atom, atomState) } return atomState } - return buildStore( + const store = buildStore( getAtomState, (atom, ...params) => atom.read(...params), (atom, ...params) => atom.write(...params), (atom, ...params) => atom.onMount?.(...params), + (store) => (atom, atomState) => atom.INTERNAL_onInit?.(store, atomState), ) + if (import.meta.env?.MODE !== 'production') { + return deriveDevStoreRev4(store) + } + return store } -let defaultStore: Store | undefined +let defaultStore: PrdOrDevStore | undefined -export const getDefaultStore = (): Store => { +export const getDefaultStore = (): PrdOrDevStore => { if (!defaultStore) { defaultStore = createStore() if (import.meta.env?.MODE !== 'production') { diff --git a/tests/setup.ts b/tests/setup.ts index a9d0dd31aa6..285ae1ca290 100644 --- a/tests/setup.ts +++ b/tests/setup.ts @@ -1 +1,17 @@ import '@testing-library/jest-dom/vitest' +import { expect, vi } from 'vitest' + +type MockFunction = ReturnType + +expect.extend({ + toHaveBeenCalledBefore(received: MockFunction, expected: MockFunction) { + const pass = + received.mock.invocationCallOrder[0]! < + expected.mock.invocationCallOrder[0]! + return { + pass, + message: () => + `expected ${received} to have been called before ${expected}`, + } + }, +}) diff --git a/tests/vanilla/effect.test.ts b/tests/vanilla/effect.test.ts new file mode 100644 index 00000000000..8acfc8e8a29 --- /dev/null +++ b/tests/vanilla/effect.test.ts @@ -0,0 +1,236 @@ +import { expect, it, vi } from 'vitest' +import type { Atom, Getter, Setter } from 'jotai/vanilla' +import { atom, createStore } from 'jotai/vanilla' + +type AnyAtom = Atom +type GetterWithPeak = Getter & { peak: Getter } +type SetterWithRecurse = Setter & { recurse: Setter } +type Cleanup = () => void +type Effect = (get: GetterWithPeak, set: SetterWithRecurse) => void | Cleanup +type Ref = { + get: GetterWithPeak + set?: SetterWithRecurse + cleanup?: Cleanup | null + fromCleanup: boolean + inProgress: number + isPending: boolean + deps: Set + sub: () => () => void + epoch: number +} + +type INTERNAL_onInit = NonNullable +type AtomState = Parameters[1] +type BatchListeners = NonNullable +type BatchEntry = BatchListeners extends Set ? U : never +type BatchListener = BatchEntry[0] +type BatchPriority = NonNullable + +function atomSyncEffect(effect: Effect) { + const refAtom = atom( + () => ({ deps: new Set(), inProgress: 0, epoch: 0 }) as Ref, + (get, set) => { + const ref = get(refAtom) + if (!ref.get.peak) { + ref.get.peak = get + } + const setter: Setter = (a, ...args) => { + try { + ++ref.inProgress + return set(a, ...args) + } finally { + --ref.inProgress + } + } + const recurse: Setter = (a, ...args) => { + if (ref.fromCleanup) { + if (import.meta.env?.MODE !== 'production') { + console.warn('cannot recurse inside cleanup') + } + return undefined as any + } + return set(a, ...args) + } + if (!ref.set) { + ref.set = Object.assign(setter, { recurse }) + } + ref.isPending = ref.inProgress === 0 + return () => { + ref.cleanup?.() + ref.cleanup = null + ref.isPending = false + ref.deps.clear() + } + }, + ) + refAtom.onMount = (mount) => mount() + const refreshAtom = atom(0) + const internalAtom = atom( + (get) => { + get(refreshAtom) + const ref = get(refAtom) + if (!ref.get) { + ref.get = ((a) => { + ref.deps.add(a) + return get(a) + }) as Getter & { peak: Getter } + } + ref.deps.forEach(get) + ref.isPending = true + return ++ref.epoch + }, + (get, set) => { + set(refreshAtom, (v) => ++v) + return get(refAtom).sub() + }, + ) + internalAtom.onMount = (mount) => mount() + internalAtom.INTERNAL_onInit = (store, atomState) => { + store.get(refAtom).sub = function subscribe() { + const batchListener: BatchListener = (_batch) => { + const ref = store.get(refAtom) + if (!ref.isPending || ref.inProgress > 0) { + return + } + ref.isPending = false + ref.cleanup?.() + const cleanup = effectAtom.effect(ref.get!, ref.set!) + ref.cleanup = + typeof cleanup === 'function' + ? () => { + try { + ref.fromCleanup = true + cleanup() + } finally { + ref.fromCleanup = false + } + } + : null + } + const priority: BatchPriority = 'H' + const priorityListener = [batchListener, priority] as const + atomState.l.add(priorityListener) + return () => atomState.l.delete(priorityListener) + } + } + const effectAtom = Object.assign( + atom((get) => void get(internalAtom)), + { effect }, + ) + return effectAtom +} + +it('responds to changes to atoms when subscribed', () => { + const store = createStore() + const a = atom(1) + const b = atom(1) + const w = atom(null, (_get, set, value: number) => { + set(a, value) + set(b, value) + }) + const results: number[] = [] + const cleanup = vi.fn() + const effect = vi.fn((get: Getter) => { + results.push(get(a) * 10 + get(b)) + return cleanup + }) + const e = atomSyncEffect(effect) + const unsub = store.sub(e, () => {}) // mount syncEffect + expect(effect).toBeCalledTimes(1) + expect(results).toStrictEqual([11]) // initial values at time of effect mount + store.set(a, 2) + expect(results).toStrictEqual([11, 21]) + store.set(b, 2) + expect(results).toStrictEqual([11, 21, 22]) + store.set(w, 3) + // intermediate state of '32' should not be recorded since the effect runs _after_ graph has been computed + expect(results).toStrictEqual([11, 21, 22, 33]) + expect(cleanup).toBeCalledTimes(3) + expect(effect).toBeCalledTimes(4) + unsub() + expect(cleanup).toBeCalledTimes(4) + expect(effect).toBeCalledTimes(4) + store.set(a, 4) + // the effect is unmounted so no more updates + expect(results).toStrictEqual([11, 21, 22, 33]) + expect(effect).toBeCalledTimes(4) +}) + +it('responds to changes to atoms when mounted with get', () => { + const store = createStore() + const a = atom(1) + const b = atom(1) + const w = atom(null, (_get, set, value: number) => { + set(a, value) + set(b, value) + }) + const results: number[] = [] + const cleanup = vi.fn() + const effect = vi.fn((get: Getter) => { + results.push(get(a) * 10 + get(b)) + return cleanup + }) + const e = atomSyncEffect(effect) + const d = atom((get) => get(e)) + const unsub = store.sub(d, () => {}) // mount syncEffect + expect(effect).toBeCalledTimes(1) + expect(results).toStrictEqual([11]) // initial values at time of effect mount + store.set(a, 2) + expect(results).toStrictEqual([11, 21]) + store.set(b, 2) + expect(results).toStrictEqual([11, 21, 22]) + store.set(w, 3) + // intermediate state of '32' should not be recorded since the effect runs _after_ graph has been computed + expect(results).toStrictEqual([11, 21, 22, 33]) + expect(cleanup).toBeCalledTimes(3) + expect(effect).toBeCalledTimes(4) + unsub() + expect(cleanup).toBeCalledTimes(4) + expect(effect).toBeCalledTimes(4) +}) + +it('sets values to atoms without causing infinite loop', () => { + const store = createStore() + const a = atom(1) + const effect = vi.fn((get: Getter, set: Setter) => { + set(a, get(a) + 1) + }) + const e = atomSyncEffect(effect) + const unsub = store.sub(e, () => {}) // mount syncEffect + expect(effect).toBeCalledTimes(1) + expect(store.get(a)).toBe(2) // initial values at time of effect mount + store.set(a, (v) => ++v) + expect(store.get(a)).toBe(4) + expect(effect).toBeCalledTimes(2) + unsub() + expect(effect).toBeCalledTimes(2) +}) + +it('reads the value with peak without subscribing to updates', () => { + const store = createStore() + const a = atom(1) + let result = 0 + const effect = vi.fn((get: GetterWithPeak) => { + result = get.peak(a) + }) + const e = atomSyncEffect(effect) + store.sub(e, () => {}) // mount syncEffect + expect(effect).toBeCalledTimes(1) + expect(result).toBe(1) // initial values at time of effect mount + store.set(a, 2) + expect(effect).toBeCalledTimes(1) +}) + +it('supports recursion', () => { + const store = createStore() + const a = atom(1) + const effect = vi.fn((get: Getter, set: SetterWithRecurse) => { + if (get(a) < 3) { + set.recurse(a, (v) => ++v) + } + }) + const e = atomSyncEffect(effect) + store.sub(e, () => {}) // mount syncEffect + expect(effect).toBeCalledTimes(3) + expect(store.get(a)).toBe(3) +}) diff --git a/tests/vanilla/store.test.tsx b/tests/vanilla/store.test.tsx index 5a186b0a3f0..428bd2549cf 100644 --- a/tests/vanilla/store.test.tsx +++ b/tests/vanilla/store.test.tsx @@ -1005,3 +1005,127 @@ it('should process all atom listeners even if some of them throw errors', () => expect(listenerB).toHaveBeenCalledTimes(1) expect(listenerC).toHaveBeenCalledTimes(1) }) + +it('should call onInit only once per atom', () => { + const store = createStore() + const a = atom(0) + const onInit = vi.fn() + a.INTERNAL_onInit = onInit + store.get(a) + expect(onInit).toHaveBeenCalledTimes(1) + const aAtomState = expect.objectContaining({ + d: expect.any(Map), + p: expect.any(Set), + n: expect.any(Number), + }) + expect(onInit).toHaveBeenCalledWith(store, aAtomState) + onInit.mockClear() + store.get(a) + store.set(a, 1) + const unsub = store.sub(a, () => {}) + unsub() + const b = atom((get) => get(a)) + store.get(b) + store.sub(b, () => {}) + expect(onInit).not.toHaveBeenCalled() +}) + +it('should call onInit only once per store', () => { + const a = atom(0) + type AtomState = Parameters['INTERNAL_onInit']>>[1] + let aAtomState: AtomState + const aOnInit = vi.fn((_store: Store, atomState: AtomState) => { + aAtomState = atomState + }) + a.INTERNAL_onInit = aOnInit + const b = atom(0) + let bAtomState: AtomState + const bOnInit = vi.fn((_store: Store, atomState: AtomState) => { + bAtomState = atomState + }) + b.INTERNAL_onInit = bOnInit + type Store = ReturnType + function testInStore(store: Store) { + store.get(a) + store.get(b) + const mockAtomState = expect.objectContaining({ + d: expect.any(Map), + p: expect.any(Set), + n: expect.any(Number), + }) + expect(aOnInit).toHaveBeenCalledTimes(1) + expect(bOnInit).toHaveBeenCalledTimes(1) + expect(aOnInit).toHaveBeenCalledWith(store, mockAtomState) + expect(bOnInit).toHaveBeenCalledWith(store, mockAtomState) + expect(aAtomState).not.toBe(bAtomState) + aOnInit.mockClear() + bOnInit.mockClear() + return store + } + testInStore(createStore()) + const store = testInStore(createStore()) + testInStore( + store.unstable_derive( + (getAtomState, readAtom, writeAtom, atomOnMount, atomOnInit) => { + const initializedAtoms = new WeakSet() + return [ + (a, atomOnInit) => { + const atomState = getAtomState(a) + if (!initializedAtoms.has(a)) { + initializedAtoms.add(a) + atomOnInit?.(a, atomState) + } + return atomState + }, + readAtom, + writeAtom, + atomOnMount, + atomOnInit, + ] + }, + ) as Store, + ) +}) + +it('should pass store and atomState to the atom initializer', () => { + expect.assertions(2) + const store = createStore() + const a = atom(null) + a.INTERNAL_onInit = (store, atomState) => { + expect(store).toBe(store) + expect(atomState).toEqual(expect.objectContaining({})) + } + store.get(a) +}) + +it('should call the batch listener with batch and respect the priority', () => { + type INTERNAL_onInit = NonNullable['INTERNAL_onInit']> + type AtomState = Parameters[1] + type BatchListeners = NonNullable + type BatchEntry = BatchListeners extends Set ? U : never + type BatchListener = BatchEntry[0] + + const a = atom(0) + const highPriorityBatchListener = vi.fn() as BatchListener + const mediumPriorityBatchListener = vi.fn() as BatchListener + const lowPriorityBatchListener = vi.fn() as BatchListener + a.INTERNAL_onInit = (_store, atomState) => { + atomState.l.add([lowPriorityBatchListener, 'L']) + atomState.l.add([highPriorityBatchListener, 'H']) + atomState.l.add([mediumPriorityBatchListener, 'M']) + } + const store = createStore() + store.set(a, 1) + const mockBatch = expect.objectContaining({}) + expect(highPriorityBatchListener).toHaveBeenCalledWith(mockBatch) + expect(mediumPriorityBatchListener).toHaveBeenCalledWith(mockBatch) + expect(lowPriorityBatchListener).toHaveBeenCalledWith(mockBatch) + // eslint-disable-next-line @vitest/valid-expect + ;(expect(highPriorityBatchListener) as any).toHaveBeenCalledBefore( + mediumPriorityBatchListener, + ) + // eslint-disable-next-line @vitest/valid-expect + ;(expect(mediumPriorityBatchListener) as any).toHaveBeenCalledBefore( + lowPriorityBatchListener, + ) +}) diff --git a/tests/vanilla/unstable_derive.test.tsx b/tests/vanilla/unstable_derive.test.tsx index f207d22f59a..dee2741e7c7 100644 --- a/tests/vanilla/unstable_derive.test.tsx +++ b/tests/vanilla/unstable_derive.test.tsx @@ -14,14 +14,14 @@ describe('unstable_derive for scoping atoms', () => { const store = createStore() const derivedStore = store.unstable_derive( - (getAtomState, atomRead, atomWrite, atomOnMount) => { + (getAtomState, atomRead, atomWrite, atomOnMount, atomOnInit) => { const scopedAtomStateMap = new WeakMap() return [ (atom) => { if (scopedAtoms.has(atom)) { let atomState = scopedAtomStateMap.get(atom) if (!atomState) { - atomState = { d: new Map(), p: new Set(), n: 0 } + atomState = { d: new Map(), l: new Set(), p: new Set(), n: 0 } scopedAtomStateMap.set(atom, atomState) } return atomState @@ -31,6 +31,7 @@ describe('unstable_derive for scoping atoms', () => { atomRead, atomWrite, atomOnMount, + atomOnInit, ] }, ) @@ -59,14 +60,14 @@ describe('unstable_derive for scoping atoms', () => { const store = createStore() const derivedStore = store.unstable_derive( - (getAtomState, atomRead, atomWrite, atomOnMount) => { + (getAtomState, atomRead, atomWrite, atomOnMount, atomOnInit) => { const scopedAtomStateMap = new WeakMap() return [ (atom) => { if (scopedAtoms.has(atom)) { let atomState = scopedAtomStateMap.get(atom) if (!atomState) { - atomState = { d: new Map(), p: new Set(), n: 0 } + atomState = { d: new Map(), l: new Set(), p: new Set(), n: 0 } scopedAtomStateMap.set(atom, atomState) } return atomState @@ -76,6 +77,7 @@ describe('unstable_derive for scoping atoms', () => { atomRead, atomWrite, atomOnMount, + atomOnInit, ] }, ) @@ -103,14 +105,14 @@ describe('unstable_derive for scoping atoms', () => { function makeStores() { const store = createStore() const derivedStore = store.unstable_derive( - (getAtomState, atomRead, atomWrite, atomOnMount) => { + (getAtomState, atomRead, atomWrite, atomOnMount, atomOnInit) => { const scopedAtomStateMap = new WeakMap() return [ (atom) => { if (scopedAtoms.has(atom)) { let atomState = scopedAtomStateMap.get(atom) if (!atomState) { - atomState = { d: new Map(), p: new Set(), n: 0 } + atomState = { d: new Map(), l: new Set(), p: new Set(), n: 0 } scopedAtomStateMap.set(atom, atomState) } return atomState @@ -128,6 +130,7 @@ describe('unstable_derive for scoping atoms', () => { }, atomWrite, atomOnMount, + atomOnInit, ] }, ) @@ -178,3 +181,30 @@ describe('unstable_derive for scoping atoms', () => { } }) }) + +it('should pass the correct store instance to the atom initializer', () => { + expect.assertions(2) + const baseStore = createStore() + const derivedStore = baseStore.unstable_derive( + (getAtomState, atomRead, atomWrite, atomOnMount, atomOnInit) => [ + (a, atomOnInit) => { + const atomState = getAtomState(a) + atomOnInit?.(a, atomState) + return atomState + }, + atomRead, + atomWrite, + atomOnMount, + atomOnInit, + ], + ) + const a = atom(null) + a.INTERNAL_onInit = (store) => { + expect(store).toBe(baseStore) + } + baseStore.get(a) + a.INTERNAL_onInit = (store) => { + expect(store).toBe(derivedStore) + } + derivedStore.get(a) +})