diff --git a/src/vanilla/store.ts b/src/vanilla/store.ts index 20e0a18508..785b3b0c50 100644 --- a/src/vanilla/store.ts +++ b/src/vanilla/store.ts @@ -194,8 +194,14 @@ const registerBatchAtom = ( ) => { if (!batch.D.has(atom)) { batch.D.set(atom, new Set()) - addBatchFunc(batch, 'M', () => { - atomState.m?.l.forEach((listener) => addBatchFunc(batch, 'M', listener)) + addBatchFunc(batch, 'H', () => { + for (const listener of atomState.m?.l || []) { + let priority: BatchPriority = 'M' + if ('INTERNAL_priority' in listener) { + priority = listener.INTERNAL_priority as BatchPriority + } + addBatchFunc(batch, priority, listener) + } }) } } @@ -214,9 +220,39 @@ const addBatchAtomDependent = ( const getBatchAtomDependents = (batch: Batch, atom: AnyAtom) => batch.D.get(atom) +const flushBatch = (batch: Batch) => { + let error: AnyError + let hasError = false + const call = (fn: () => void) => { + try { + fn() + } catch (e) { + if (!hasError) { + error = e + hasError = true + } + } + } + while (batch.H.size || batch.M.size || batch.L.size) { + batch.D.clear() + batch.H.forEach(call) + batch.H.clear() + batch.M.forEach(call) + batch.M.clear() + batch.L.forEach(call) + batch.L.clear() + } + if (hasError) { + throw error + } +} + // internal & unstable type type StoreArgs = readonly [ - getAtomState: (atom: Atom) => AtomState, + getAtomState: ( + atom: Atom, + atomOnInit?: (atom: AnyAtom) => void, + ) => AtomState, atomRead: ( atom: Atom, ...params: Parameters['read']> @@ -229,6 +265,7 @@ type StoreArgs = readonly [ atom: WritableAtom, setAtom: (...args: Args) => Result, ) => OnUnmount | void, + atomOnInit: (store: Store) => (atom: AnyAtom) => void, ] // for debugging purpose only @@ -248,7 +285,6 @@ type PrdStore = { ) => Result sub: (atom: AnyAtom, listener: () => void) => () => void unstable_derive: (fn: (...args: StoreArgs) => StoreArgs) => Store - unstable_onChange: (handler: OnChangeHandler) => () => void } export type Store = PrdStore | (PrdStore & DevStoreRev4) @@ -256,15 +292,12 @@ export type Store = PrdStore | (PrdStore & DevStoreRev4) export type INTERNAL_DevStoreRev4 = DevStoreRev4 export type INTERNAL_PrdStore = PrdStore -type OnChangeHandler = ( - changedAtoms: Set, - mountedAtoms: Set, - unmountedAtoms: Set, -) => void - const buildStore = ( - ...[getAtomState, atomRead, atomWrite, atomOnMount]: StoreArgs + ...[baseGetAtomState, atomRead, atomWrite, atomOnMount, atomOnInit]: StoreArgs ): Store => { + const getAtomState = (atom: Atom) => + baseGetAtomState(atom, atomOnInit(store)) + // for debugging purpose only let debugMountedAtoms: Set @@ -691,68 +724,15 @@ const buildStore = ( } const unstable_derive = (fn: (...args: StoreArgs) => StoreArgs) => - buildStore(...fn(getAtomState, atomRead, atomWrite, atomOnMount)) - - const onChangeHandlers = new Set() - - const unstable_onChange = (handler: OnChangeHandler) => { - onChangeHandlers.add(handler) - return () => { - onChangeHandlers.delete(handler) - } - } - - const flushBatch = (batch: Batch) => { - let error: AnyError - let hasError = false - const call = (fn: () => void) => { - try { - fn() - } catch (e) { - if (!hasError) { - error = e - hasError = true - } - } - } - const shouldContinue = () => ([1, 2, 3] as const).some((p) => batch[p].size) - do { - const changedAtoms = new Set(batch.D.keys()) - while (shouldContinue()) { - batch.D.clear() - batch.H.forEach(call) - batch.H.clear() - batch.M.forEach(call) - batch.M.clear() - batch.L.forEach(call) - batch.L.clear() - const moreChangedAtoms = new Set(batch.D.keys()) - for (const atom of moreChangedAtoms) { - changedAtoms.add(atom) - } - } - // Process onChange handlers after all atoms are updated - if (changedAtoms.size || batch.M.size || batch.U.size) { - const mountedAtoms = new Set(batch.M) - const unmountedAtoms = new Set(batch.U) - batch.M.clear() - batch.U.clear() - for (const handler of onChangeHandlers) { - handler(changedAtoms, mountedAtoms, unmountedAtoms) - } - } - } while (shouldContinue()) - if (hasError) { - throw error - } - } + buildStore( + ...fn(baseGetAtomState, atomRead, atomWrite, atomOnMount, atomOnInit), + ) const store: Store = { get: readAtom, set: writeAtom, sub: subscribeAtom, unstable_derive, - unstable_onChange, } if (import.meta.env?.MODE !== 'production') { const devStore: DevStoreRev4 = { @@ -792,27 +772,25 @@ const buildStore = ( export const createStore = (): Store => { const atomStateMap = new WeakMap() - const getAtomState = (atom: Atom) => { + const getAtomState: StoreArgs[0] = (atom, onInit) => { if (import.meta.env?.MODE !== 'production' && !atom) { throw new Error('Atom is undefined or null') } - let atomState = atomStateMap.get(atom) as AtomState | undefined + let atomState = atomStateMap.get(atom) as AtomState | undefined if (!atomState) { atomState = { d: new Map(), p: new Set(), n: 0 } atomStateMap.set(atom, atomState) - if (typeof atom.unstable_onInit === 'function') { - atom.unstable_onInit(store) - } + onInit?.(atom) } return atomState } - const store = buildStore( + return buildStore( getAtomState, (atom, ...params) => atom.read(...params), (atom, ...params) => atom.write(...params), (atom, ...params) => atom.onMount?.(...params), + (store) => (atom) => atom.unstable_onInit?.(store), ) - return store } let defaultStore: Store | undefined diff --git a/tests/vanilla/effect.test.ts b/tests/vanilla/effect.test.ts index a5577cf563..7f1457fb27 100644 --- a/tests/vanilla/effect.test.ts +++ b/tests/vanilla/effect.test.ts @@ -15,12 +15,13 @@ type Ref = { inProgress: number isPending: boolean deps: Set - unsub?: () => void + sub: () => () => void + epoch: number } function atomSyncEffect(effect: Effect) { const refAtom = atom( - () => ({ deps: new Set(), inProgress: 0 }) as Ref, + () => ({ deps: new Set(), inProgress: 0, epoch: 0 }) as Ref, (get, set) => { const ref = get(refAtom) if (!ref.get.peak) { @@ -36,7 +37,7 @@ function atomSyncEffect(effect: Effect) { } const recurse: Setter = (a, ...args) => { if (ref.fromCleanup) { - if (process.env.NODE_ENV !== 'production') { + if (import.meta.env?.MODE !== 'production') { console.warn('cannot recurse inside cleanup') } return undefined as any @@ -52,12 +53,13 @@ function atomSyncEffect(effect: Effect) { ref.cleanup = null ref.isPending = false ref.deps.clear() - ref.unsub?.() } }, ) refAtom.onMount = (mount) => mount() + const refreshAtom = atom(0) const internalAtom = atom((get) => { + get(refreshAtom) const ref = get(refAtom) if (!ref.get) { ref.get = ((a) => { @@ -67,36 +69,46 @@ function atomSyncEffect(effect: Effect) { } ref.deps.forEach(get) ref.isPending = true + return ++ref.epoch }) - internalAtom.unstable_onInit = (store) => { - const unsub = store.unstable_onChange(() => { - const ref = store.get(refAtom) - if (!ref.isPending || ref.inProgress > 0) { - return - } - ref.isPending = false - ref.cleanup?.() - const cleanup = effectAtom.effect(ref.get!, ref.set!) - ref.cleanup = - typeof cleanup === 'function' - ? () => { - try { - ref.fromCleanup = true - cleanup() - } finally { - ref.fromCleanup = false - } - } - : null - }) - new FinalizationRegistry(unsub).register(internalAtom, null) - } - if (process.env.NODE_ENV !== 'production') { - refAtom.debugPrivate = true - internalAtom.debugPrivate = true + const bridgeAtom = atom( + (get) => get(internalAtom), + (get, set) => { + set(refreshAtom, (v) => ++v) + return get(refAtom).sub() + }, + ) + bridgeAtom.onMount = (mount) => mount() + bridgeAtom.unstable_onInit = (store) => { + store.get(refAtom).sub = () => { + const listener = Object.assign( + () => { + const ref = store.get(refAtom) + if (!ref.isPending || ref.inProgress > 0) { + return + } + ref.isPending = false + ref.cleanup?.() + const cleanup = effectAtom.effect(ref.get!, ref.set!) + ref.cleanup = + typeof cleanup === 'function' + ? () => { + try { + ref.fromCleanup = true + cleanup() + } finally { + ref.fromCleanup = false + } + } + : null + }, + { INTERNAL_priority: 'H' }, + ) + return store.sub(internalAtom, listener) + } } const effectAtom = Object.assign( - atom((get) => void get(internalAtom)), + atom((get) => void get(bridgeAtom)), { effect }, ) return effectAtom @@ -105,14 +117,11 @@ function atomSyncEffect(effect: Effect) { it('responds to changes to atoms when subscribed', () => { const store = createStore() const a = atom(1) - a.debugLabel = 'a' const b = atom(1) - b.debugLabel = 'b' const w = atom(null, (_get, set, value: number) => { set(a, value) set(b, value) }) - w.debugLabel = 'w' const results: number[] = [] const cleanup = vi.fn() const effect = vi.fn((get: Getter) => { @@ -120,7 +129,6 @@ it('responds to changes to atoms when subscribed', () => { return cleanup }) const e = atomSyncEffect(effect) - e.debugLabel = 'e' const unsub = store.sub(e, () => {}) // mount syncEffect expect(effect).toBeCalledTimes(1) expect(results).toStrictEqual([11]) // initial values at time of effect mount @@ -145,14 +153,11 @@ it('responds to changes to atoms when subscribed', () => { it('responds to changes to atoms when mounted with get', () => { const store = createStore() const a = atom(1) - a.debugLabel = 'a' const b = atom(1) - b.debugLabel = 'b' const w = atom(null, (_get, set, value: number) => { set(a, value) set(b, value) }) - w.debugLabel = 'w' const results: number[] = [] const cleanup = vi.fn() const effect = vi.fn((get: Getter) => { @@ -160,9 +165,7 @@ it('responds to changes to atoms when mounted with get', () => { return cleanup }) const e = atomSyncEffect(effect) - e.debugLabel = 'e' const d = atom((get) => get(e)) - d.debugLabel = 'd' const unsub = store.sub(d, () => {}) // mount syncEffect expect(effect).toBeCalledTimes(1) expect(results).toStrictEqual([11]) // initial values at time of effect mount @@ -183,12 +186,10 @@ it('responds to changes to atoms when mounted with get', () => { it('sets values to atoms without causing infinite loop', () => { const store = createStore() const a = atom(1) - a.debugLabel = 'a' const effect = vi.fn((get: Getter, set: Setter) => { set(a, get(a) + 1) }) const e = atomSyncEffect(effect) - e.debugLabel = 'e' const unsub = store.sub(e, () => {}) // mount syncEffect expect(effect).toBeCalledTimes(1) expect(store.get(a)).toBe(2) // initial values at time of effect mount @@ -202,13 +203,11 @@ it('sets values to atoms without causing infinite loop', () => { it('reads the value with peak without subscribing to updates', () => { const store = createStore() const a = atom(1) - a.debugLabel = 'a' let result = 0 const effect = vi.fn((get: GetterWithPeak) => { result = get.peak(a) }) const e = atomSyncEffect(effect) - e.debugLabel = 'e' store.sub(e, () => {}) // mount syncEffect expect(effect).toBeCalledTimes(1) expect(result).toBe(1) // initial values at time of effect mount @@ -219,15 +218,13 @@ it('reads the value with peak without subscribing to updates', () => { it('supports recursion', () => { const store = createStore() const a = atom(1) - a.debugLabel = 'a' const effect = vi.fn((get: Getter, set: SetterWithRecurse) => { if (get(a) < 3) { set.recurse(a, (v) => ++v) } }) const e = atomSyncEffect(effect) - e.debugLabel = 'e' - const unsub = store.sub(e, () => {}) // mount syncEffect + store.sub(e, () => {}) // mount syncEffect expect(effect).toBeCalledTimes(3) expect(store.get(a)).toBe(3) }) diff --git a/tests/vanilla/store.test.tsx b/tests/vanilla/store.test.tsx index 5a186b0a3f..d95f0f7699 100644 --- a/tests/vanilla/store.test.tsx +++ b/tests/vanilla/store.test.tsx @@ -1005,3 +1005,65 @@ it('should process all atom listeners even if some of them throw errors', () => expect(listenerB).toHaveBeenCalledTimes(1) expect(listenerC).toHaveBeenCalledTimes(1) }) + +it('should call onInit only once per atom', () => { + const store = createStore() + const a = atom(0) + const onInit = vi.fn() + a.unstable_onInit = onInit + store.get(a) + expect(onInit).toHaveBeenCalledTimes(1) + expect(onInit).toHaveBeenCalledWith(store) + onInit.mockClear() + store.get(a) + store.set(a, 1) + const unsub = store.sub(a, () => {}) + unsub() + const b = atom((get) => get(a)) + store.get(b) + store.sub(b, () => {}) + expect(onInit).not.toHaveBeenCalled() +}) + +it('should call onInit only once per store', () => { + const a = atom(0) + const aOnInit = vi.fn() + a.unstable_onInit = aOnInit + const b = atom(0) + const bOnInit = vi.fn() + b.unstable_onInit = bOnInit + type Store = ReturnType + function testInStore(store: Store) { + store.get(a) + store.get(b) + expect(aOnInit).toHaveBeenCalledTimes(1) + expect(bOnInit).toHaveBeenCalledTimes(1) + expect(aOnInit).toHaveBeenCalledWith(store) + expect(bOnInit).toHaveBeenCalledWith(store) + aOnInit.mockClear() + bOnInit.mockClear() + return store + } + testInStore(createStore()) + const store = testInStore(createStore()) + testInStore( + store.unstable_derive( + (getAtomState, readAtom, writeAtom, atomOnMount, atomOnInit) => { + const initializedAtoms = new WeakSet() + return [ + (a, atomOnInit) => { + if (!initializedAtoms.has(a)) { + initializedAtoms.add(a) + atomOnInit?.(a) + } + return getAtomState(a, atomOnInit) + }, + readAtom, + writeAtom, + atomOnMount, + atomOnInit, + ] + }, + ), + ) +}) diff --git a/tests/vanilla/unstable_derive.test.tsx b/tests/vanilla/unstable_derive.test.tsx index f207d22f59..da3ece1081 100644 --- a/tests/vanilla/unstable_derive.test.tsx +++ b/tests/vanilla/unstable_derive.test.tsx @@ -14,7 +14,7 @@ describe('unstable_derive for scoping atoms', () => { const store = createStore() const derivedStore = store.unstable_derive( - (getAtomState, atomRead, atomWrite, atomOnMount) => { + (getAtomState, atomRead, atomWrite, atomOnMount, atomOnInit) => { const scopedAtomStateMap = new WeakMap() return [ (atom) => { @@ -26,11 +26,12 @@ describe('unstable_derive for scoping atoms', () => { } return atomState } - return getAtomState(atom) + return getAtomState(atom, atomOnInit(derivedStore)) }, atomRead, atomWrite, atomOnMount, + atomOnInit, ] }, ) @@ -59,7 +60,7 @@ describe('unstable_derive for scoping atoms', () => { const store = createStore() const derivedStore = store.unstable_derive( - (getAtomState, atomRead, atomWrite, atomOnMount) => { + (getAtomState, atomRead, atomWrite, atomOnMount, atomOnInit) => { const scopedAtomStateMap = new WeakMap() return [ (atom) => { @@ -71,11 +72,12 @@ describe('unstable_derive for scoping atoms', () => { } return atomState } - return getAtomState(atom) + return getAtomState(atom, atomOnInit(derivedStore)) }, atomRead, atomWrite, atomOnMount, + atomOnInit, ] }, ) @@ -103,7 +105,7 @@ describe('unstable_derive for scoping atoms', () => { function makeStores() { const store = createStore() const derivedStore = store.unstable_derive( - (getAtomState, atomRead, atomWrite, atomOnMount) => { + (getAtomState, atomRead, atomWrite, atomOnMount, atomOnInit) => { const scopedAtomStateMap = new WeakMap() return [ (atom) => { @@ -115,7 +117,7 @@ describe('unstable_derive for scoping atoms', () => { } return atomState } - return getAtomState(atom) + return getAtomState(atom, atomOnInit(derivedStore)) }, (a, get, options) => { const myGet: Getter = (aa) => { @@ -128,6 +130,7 @@ describe('unstable_derive for scoping atoms', () => { }, atomWrite, atomOnMount, + atomOnInit, ] }, ) @@ -178,3 +181,17 @@ describe('unstable_derive for scoping atoms', () => { } }) }) + +it('should pass the correct store instance to the atom initializer', () => { + const baseStore = createStore() + const derivedStore = baseStore.unstable_derive((...args) => args) + const a = atom(null) + a.unstable_onInit = (store) => { + expect(store).toBe(baseStore) + } + baseStore.get(a) + a.unstable_onInit = (store) => { + expect(store).toBe(derivedStore) + } + derivedStore.get(a) +})