diff --git a/src/vanilla/store.ts b/src/vanilla/store.ts index 422c4ee99f..7fbe99ea42 100644 --- a/src/vanilla/store.ts +++ b/src/vanilla/store.ts @@ -165,9 +165,32 @@ type Pending = readonly [ dependents: Map>, atomStates: Map, functions: Set<() => void>, + // set of dependents of the dirtied atoms pending recompute + // `set`: + // when an atom is dirtied, its dependents (deep) are added to this set + // atoms are removed from this set when they are recomputed + // `get`: + // when an atom is read by `get` if it is in the recompute pending set, + // it and it's dependencies are recomputed in place + // all remaining dependencies are recomputed in flush pending + pendingRecompute: [dependentMap: DependentMap, changedAtoms: Set], ] -const createPending = (): Pending => [new Map(), new Map(), new Set()] +type DependentMap = Map< + AnyAtom, + [dependents: Set, AtomState: AtomState, epoch: number] +> + +const createPending = (): Pending => [ + /** dependents */ + new Map(), + /** atomStates */ + new Map(), + /** functions */ + new Set(), + /** pendingRecompute */ + [new Map(), new Set()], +] const addPendingAtom = ( pending: Pending, @@ -198,33 +221,6 @@ const addPendingFunction = (pending: Pending, fn: () => void) => { pending[2].add(fn) } -const flushPending = (pending: Pending) => { - let error: AnyError - let hasError = false - const call = (fn: () => void) => { - try { - fn() - } catch (e) { - if (!hasError) { - error = e - hasError = true - } - } - } - while (pending[1].size || pending[2].size) { - pending[0].clear() - const atomStates = new Set(pending[1].values()) - pending[1].clear() - const functions = new Set(pending[2]) - pending[2].clear() - atomStates.forEach((atomState) => atomState.m?.l.forEach(call)) - functions.forEach(call) - } - if (hasError) { - throw error - } -} - // internal & unstable type type StoreArgs = readonly [ getAtomState: (atom: Atom) => AtomState, @@ -276,6 +272,58 @@ const buildStore = ( debugMountedAtoms = new Set() } + const flushPending = (pending: Pending) => { + let error: AnyError + let hasError = false + const call = (fn: () => void) => { + try { + fn() + } catch (e) { + if (!hasError) { + error = e + hasError = true + } + } + } + if (pending[3]?.[0].size) { + recomputeDependents(pending, pending[3][0], pending[3][1]) + pending[3][0].clear() + pending[3][1].clear() + } + while (pending[1].size || pending[2].size) { + pending[0].clear() + const atomStates = new Set(pending[1].values()) + pending[1].clear() + const functions = new Set(pending[2]) + pending[2].clear() + atomStates.forEach((atomState) => atomState.m?.l.forEach(call)) + functions.forEach(call) + } + if (hasError) { + throw error + } + } + + /** + * adds the atom and its dependents to the recompute pending list + */ + const addPendingRecompute = ( + pending: Pending, + atom: AnyAtom, + atomState: AtomState, + ) => { + const dependents = getAllDependents(pending, atom, atomState) + for (const [dependent, entry] of dependents.entries()) { + pending[3][0].set(dependent, entry) + } + pending[3][1].add(atom) + } + + const removePendingRecompute = (pending: Pending, dependent: AnyAtom) => { + pending[3][0].delete(dependent) + pending[3][1].delete(dependent) + } + const setAtomStateValueOrPromise = ( atom: AnyAtom, atomState: AtomState, @@ -306,7 +354,7 @@ const buildStore = ( const readAtomState = ( pending: Pending | undefined, atom: Atom, - dirtyAtoms?: Set, + dirtyAtoms?: { has: (atom: AnyAtom) => boolean }, ): AtomState => { const atomState = getAtomState(atom) // See if we can skip recomputing this atom. @@ -439,23 +487,47 @@ const buildStore = ( return dependents } + function getAllDependents( + pending: Pending, + atom: AnyAtom, + atomState: AtomState, + ): DependentMap { + const visited: DependentMap = new Map() + const stack: [AnyAtom, AtomState][] = [[atom, atomState]] + while (stack.length > 0) { + const [a, aState] = stack.pop()! + if (visited.has(a)) { + continue + } + const dependents = getDependents(pending, a, aState) + visited.set(a, [new Set(dependents.keys()), aState, aState.n]) + for (const [d, dState] of dependents) { + if (!visited.has(d)) { + stack.push([d, dState]) + } + } + } + return visited + } + // This is a topological sort via depth-first search, slightly modified from // what's described here for simplicity and performance reasons: // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search function getSortedDependents( - pending: Pending, - rootAtom: AnyAtom, - rootAtomState: AtomState, - ): [[AnyAtom, AtomState, number][], Set] { - const sorted: [atom: AnyAtom, atomState: AtomState, epochNumber: number][] = - [] + dependents: Map, ...unknown[]]>, + changedAtoms: Set, + ): Iterable { + const sorted: AnyAtom[] = [] const visiting = new Set() const visited = new Set() // Visit the root atom. This is the only atom in the dependency graph // without incoming edges, which is one reason we can simplify the algorithm - const stack: [a: AnyAtom, aState: AtomState][] = [[rootAtom, rootAtomState]] + const stack: [a: AnyAtom, depSet: Set][] = Array.from( + changedAtoms, + (a) => [a, dependents.get(a)![0]], + ) while (stack.length > 0) { - const [a, aState] = stack[stack.length - 1]! + const [a, depSet] = stack[stack.length - 1]! if (visited.has(a)) { // All dependents have been processed, now process this atom stack.pop() @@ -465,7 +537,7 @@ const buildStore = ( // The algorithm calls for pushing onto the front of the list. For // performance, we will simply push onto the end, and then will iterate in // reverse order later. - sorted.push([a, aState, aState.n]) + sorted.push(a) // Atom has been visited but not yet processed visited.add(a) stack.pop() @@ -473,49 +545,46 @@ const buildStore = ( } visiting.add(a) // Push unvisited dependents onto the stack - for (const [d, s] of getDependents(pending, a, aState)) { - if (a !== d && !visiting.has(d)) { - stack.push([d, s]) + for (const d of depSet) { + if (a !== d && !visiting.has(d) && dependents.has(d)) { + stack.push([d, dependents.get(d)![0]]) } } } - return [sorted, visited] + function* reverse(items: ReadonlyArray): Generator { + for (let i = items.length - 1; i >= 0; i--) { + yield items[i]! + } + } + + return reverse(sorted) } const recomputeDependents = ( pending: Pending, - atom: Atom, - atomState: AtomState, + dependentMap: DependentMap, + changedAtoms = new Set(), ) => { - // Step 1: traverse the dependency graph to build the topsorted atom list - // We don't bother to check for cycles, which simplifies the algorithm. - const [topsortedAtoms, markedAtoms] = getSortedDependents( - pending, - atom, - atomState, - ) - - // Step 2: use the topsorted atom list to recompute all affected atoms - // Track what's changed, so that we can short circuit when possible - const changedAtoms = new Set([atom]) - for (let i = topsortedAtoms.length - 1; i >= 0; --i) { - const [a, aState, prevEpochNumber] = topsortedAtoms[i]! - let hasChangedDeps = false - for (const dep of aState.d.keys()) { - if (dep !== a && changedAtoms.has(dep)) { - hasChangedDeps = true - break - } - } - if (hasChangedDeps) { - readAtomState(pending, a, markedAtoms) + const hasChangedDeps = (dependents: Set, a: AnyAtom) => { + return Array.from(dependents).some( + (dep) => dep !== a && changedAtoms.has(dep), + ) + } + // traverse the dependency graph to build the topsorted atom list + for (const a of getSortedDependents(dependentMap, changedAtoms)) { + // use the topsorted atom list to recompute all affected atoms + // Track what's changed, so that we can short circuit when possible + const [dependents, aState, prevEpochNumber] = dependentMap.get(a)! + if (hasChangedDeps(dependents, a)) { + readAtomState(pending, a, dependentMap) mountDependencies(pending, a, aState) if (prevEpochNumber !== aState.n) { addPendingAtom(pending, a, aState) changedAtoms.add(a) } } - markedAtoms.delete(a) + dependentMap.delete(a) + removePendingRecompute(pending, a) } } @@ -525,8 +594,17 @@ const buildStore = ( ...args: Args ): Result => { let isSync = true - const getter: Getter = (a: Atom) => - returnAtomValue(readAtomState(pending, a)) + const getter: Getter = (a: Atom) => { + /* + Check if the atom or its dependencies (deep) are in the set of recompute pending. + */ + const pendingRecompute = pending[3][0] + if (pendingRecompute.has(a)) { + const dependents = getAllDependents(pending, a, getAtomState(a)) + recomputeDependents(pending, dependents, new Set([a])) + } + return returnAtomValue(readAtomState(pending, a)) + } const setter: Setter = ( a: WritableAtom, ...args: As @@ -544,7 +622,11 @@ const buildStore = ( mountDependencies(pending, a, aState) if (prevEpochNumber !== aState.n) { addPendingAtom(pending, a, aState) - recomputeDependents(pending, a, aState) + /* + Add the atoms that depend on `a` to the recompute pending list. + pending[3][0] is the set of dependents of the dirty atom. + */ + addPendingRecompute(pending, a, aState) } return undefined as R } else { @@ -730,7 +812,9 @@ const buildStore = ( mountDependencies(pending, atom, atomState) if (prevEpochNumber !== atomState.n) { addPendingAtom(pending, atom, atomState) - recomputeDependents(pending, atom, atomState) + if (!pending[3][0].has(atom)) { + addPendingRecompute(pending, atom, atomState) + } } } }