Skip to content

Commit

Permalink
defer recompute dependents if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
dmaskasky committed Dec 13, 2024
1 parent 51218f1 commit e0e0931
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 131 deletions.
258 changes: 154 additions & 104 deletions src/vanilla/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ type AtomState<Value = AnyValue> = {
v?: Value
/** Atom error */
e?: AnyError
/** Indicates whether the atom value is has been changed */
x?: boolean
}

const isAtomStateInitialized = <Value>(atomState: AtomState<Value>) =>
Expand Down Expand Up @@ -167,7 +169,14 @@ type Pending = readonly [
functions: Set<() => void>,
]

const createPending = (): Pending => [new Map(), new Map(), new Set()]
const createPending = (): Pending => [
/** dependents */
new Map(),
/** atomStates */
new Map(),
/** functions */
new Set(),
]

const addPendingAtom = (
pending: Pending,
Expand Down Expand Up @@ -198,33 +207,6 @@ const addPendingFunction = (pending: Pending, fn: () => void) => {
pending[2].add(fn)
}

const flushPending = (pending: Pending) => {
let error: AnyError
let hasError = false
const call = (fn: () => void) => {
try {
fn()
} catch (e) {
if (!hasError) {
error = e
hasError = true
}
}
}
while (pending[1].size || pending[2].size) {
pending[0].clear()
const atomStates = new Set(pending[1].values())
pending[1].clear()
const functions = new Set(pending[2])
pending[2].clear()
atomStates.forEach((atomState) => atomState.m?.l.forEach(call))
functions.forEach(call)
}
if (hasError) {
throw error
}
}

// internal & unstable type
type StoreArgs = readonly [
getAtomState: <Value>(atom: Atom<Value>) => AtomState<Value>,
Expand Down Expand Up @@ -276,6 +258,33 @@ const buildStore = (
debugMountedAtoms = new Set()
}

const flushPending = (pending: Pending) => {
let error: AnyError
let hasError = false
const call = (fn: () => void) => {
try {
fn()
} catch (e) {
if (!hasError) {
error = e
hasError = true
}
}
}
while (pending[0].size || pending[1].size || pending[2].size) {
recomputeDependents(pending, new Set(pending[0].keys()))
const atomStates = new Set(pending[1].values())
pending[1].clear()
const functions = new Set(pending[2])
pending[2].clear()
atomStates.forEach((atomState) => atomState.m?.l.forEach(call))
functions.forEach(call)
}
if (hasError) {
throw error
}
}

const setAtomStateValueOrPromise = (
atom: AnyAtom,
atomState: AtomState,
Expand Down Expand Up @@ -306,15 +315,14 @@ const buildStore = (
const readAtomState = <Value>(
pending: Pending | undefined,
atom: Atom<Value>,
dirtyAtoms?: Set<AnyAtom>,
): AtomState<Value> => {
const atomState = getAtomState(atom)
// See if we can skip recomputing this atom.
if (isAtomStateInitialized(atomState)) {
// If the atom is mounted, we can use cached atom state.
// because it should have been updated by dependencies.
// We can't use the cache if the atom is dirty.
if (atomState.m && !dirtyAtoms?.has(atom)) {
if (atomState.m && !atomState.x) {
return atomState
}
// Otherwise, check if the dependencies have changed.
Expand All @@ -324,7 +332,7 @@ const buildStore = (
([a, n]) =>
// Recursively, read the atom state of the dependency, and
// check if the atom epoch number is unchanged
readAtomState(pending, a, dirtyAtoms).n === n,
readAtomState(pending, a).n === n,
)
) {
return atomState
Expand All @@ -347,7 +355,7 @@ const buildStore = (
return returnAtomValue(aState)
}
// a !== atom
const aState = readAtomState(pending, a, dirtyAtoms)
const aState = readAtomState(pending, a)
try {
return returnAtomValue(aState)
} finally {
Expand Down Expand Up @@ -418,108 +426,150 @@ const buildStore = (
const readAtom = <Value>(atom: Atom<Value>): Value =>
returnAtomValue(readAtomState(undefined, atom))

const getMountedOrPendingDependents = <Value>(
const markRecomputePending = (
pending: Pending,
atom: Atom<Value>,
atomState: AtomState<Value>,
): Map<AnyAtom, AtomState> => {
const dependents = new Map<AnyAtom, AtomState>()
for (const a of atomState.m?.t || []) {
atom: AnyAtom,
atomState: AtomState,
) => {
addPendingAtom(pending, atom, atomState)
if (isPendingRecompute(atom)) {
return
}
const dependents = getAllDependents(pending, [atom])
for (const [dependent] of dependents) {
getAtomState(dependent).x = true
}
}

const markRecomputeComplete = (
pending: Pending,
atom: AnyAtom,
atomState: AtomState,
) => {
atomState.x = false
pending[0].delete(atom)
}

const isPendingRecompute = (atom: AnyAtom) => getAtomState(atom).x

const getMountedDependents = (
pending: Pending,
a: AnyAtom,
aState: AtomState,
) => {
return new Set<AnyAtom>(
[
...(aState.m?.t || []),
...aState.p,
...(getPendingDependents(pending, a) || []),
].filter((a) => getAtomState(a).m),
)
}

/** @returns map of all dependents or dependencies (deep) of the root atoms */
const getDeep = (
/** function to get immediate dependents or dependencies of the atom */
getDeps: (a: AnyAtom, aState: AtomState) => Iterable<AnyAtom>,
rootAtoms: Iterable<AnyAtom>,
) => {
const visited = new Map<AnyAtom, Set<AnyAtom>>()
const stack: AnyAtom[] = Array.from(rootAtoms)
while (stack.length > 0) {
const a = stack.pop()!
const aState = getAtomState(a)
if (aState.m) {
dependents.set(a, aState)
if (visited.has(a)) {
continue
}
const deps = new Set(getDeps(a, aState))
visited.set(a, deps)
for (const d of deps) {
if (!visited.has(d)) {
stack.push(d)
}
}
}
for (const atomWithPendingPromise of atomState.p) {
dependents.set(
atomWithPendingPromise,
getAtomState(atomWithPendingPromise),
)
}
getPendingDependents(pending, atom)?.forEach((dependent) => {
dependents.set(dependent, getAtomState(dependent))
})
return dependents
return visited
}

const getAllDependents = (pending: Pending, atoms: Iterable<AnyAtom>) =>
getDeep((a, aState) => getMountedDependents(pending, a, aState), atoms)

// This is a topological sort via depth-first search, slightly modified from
// what's described here for simplicity and performance reasons:
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
function getSortedDependents(
const getSortedDependents = (
pending: Pending,
rootAtom: AnyAtom,
rootAtomState: AtomState,
): [[AnyAtom, AtomState, number][], Set<AnyAtom>] {
const sorted: [atom: AnyAtom, atomState: AtomState, epochNumber: number][] =
[]
rootAtoms: Iterable<AnyAtom>,
) => {
const atomMap = getAllDependents(pending, rootAtoms)
const sorted: AnyAtom[] = []
const visiting = new Set<AnyAtom>()
const visited = new Set<AnyAtom>()
// Visit the root atom. This is the only atom in the dependency graph
// Visit the root atoms. These are the only atoms in the dependency graph
// without incoming edges, which is one reason we can simplify the algorithm
const stack: [a: AnyAtom, aState: AtomState][] = [[rootAtom, rootAtomState]]
const stack: [a: AnyAtom, dependents: Set<AnyAtom>][] = []
for (const a of rootAtoms) {
if (atomMap.has(a)) {
stack.push([a, atomMap.get(a)!])
}
}
while (stack.length > 0) {
const [a, aState] = stack[stack.length - 1]!
const [a, dependents] = stack[stack.length - 1]!
if (visited.has(a)) {
// All dependents have been processed, now process this atom
stack.pop()
continue
}
if (visiting.has(a)) {
// The algorithm calls for pushing onto the front of the list. For
// performance, we will simply push onto the end, and then will iterate in
// reverse order later.
sorted.push([a, aState, aState.n])
// The algorithm calls for pushing onto the front of the list.
// For performance we push on the end, and will reverse the order later.
sorted.push(a)
// Atom has been visited but not yet processed
visited.add(a)
stack.pop()
continue
}
visiting.add(a)
// Push unvisited dependents onto the stack
for (const [d, s] of getMountedOrPendingDependents(pending, a, aState)) {
if (a !== d && !visiting.has(d)) {
stack.push([d, s])
for (const d of dependents) {
if (a !== d && !visiting.has(d) && atomMap.has(d)) {
stack.push([d, atomMap.get(d)!])
}
}
}
return [sorted, visited]
return sorted.reverse()
}

const recomputeDependents = <Value>(
pending: Pending,
atom: Atom<Value>,
atomState: AtomState<Value>,
) => {
// Step 1: traverse the dependency graph to build the topsorted atom list
// We don't bother to check for cycles, which simplifies the algorithm.
const [topsortedAtoms, markedAtoms] = getSortedDependents(
pending,
atom,
atomState,
)

// Step 2: use the topsorted atom list to recompute all affected atoms
// Track what's changed, so that we can short circuit when possible
const changedAtoms = new Set<AnyAtom>([atom])
for (let i = topsortedAtoms.length - 1; i >= 0; --i) {
const [a, aState, prevEpochNumber] = topsortedAtoms[i]!
let hasChangedDeps = false
for (const dep of aState.d.keys()) {
if (dep !== a && changedAtoms.has(dep)) {
hasChangedDeps = true
break
}
}
if (hasChangedDeps) {
readAtomState(pending, a, markedAtoms)
const recomputeDependents = (pending: Pending, rootAtoms: Set<AnyAtom>) => {
if (rootAtoms.size === 0) {
return
}
const hasChangedDeps = (aState: AtomState) =>
Array.from(aState.d.keys()).some((d) => rootAtoms.has(d))
// traverse the dependency graph to build the topsorted atom list
for (const a of getSortedDependents(pending, rootAtoms)) {
// use the topsorted atom list to recompute all affected atoms
// Track what's changed, so that we can short circuit when possible
const aState = getAtomState(a)
const prevEpochNumber = aState.n
if (isPendingRecompute(a) || hasChangedDeps(aState)) {
readAtomState(pending, a)
mountDependencies(pending, a, aState)
if (prevEpochNumber !== aState.n) {
addPendingAtom(pending, a, aState)
changedAtoms.add(a)
markRecomputePending(pending, a, aState)
}
}
markedAtoms.delete(a)
markRecomputeComplete(pending, a, aState)
}
}

const recomputeDependencies = (pending: Pending, a: AnyAtom) => {
if (!isPendingRecompute(a)) {
return
}
const getDependencies = (_: unknown, aState: AtomState) => aState.d.keys()
const dependencies = Array.from(getDeep(getDependencies, [a]).keys())
const dirtyDependencies = new Set(dependencies.filter(isPendingRecompute))
recomputeDependents(pending, dirtyDependencies)
}

const writeAtomState = <Value, Args extends unknown[], Result>(
Expand All @@ -528,8 +578,10 @@ const buildStore = (
...args: Args
): Result => {
let isSync = true
const getter: Getter = <V>(a: Atom<V>) =>
returnAtomValue(readAtomState(pending, a))
const getter: Getter = <V>(a: Atom<V>) => {
recomputeDependencies(pending, atom)
return returnAtomValue(readAtomState(pending, a))
}
const setter: Setter = <V, As extends unknown[], R>(
a: WritableAtom<V, As, R>,
...args: As
Expand All @@ -546,8 +598,7 @@ const buildStore = (
setAtomStateValueOrPromise(a, aState, v)
mountDependencies(pending, a, aState)
if (prevEpochNumber !== aState.n) {
addPendingAtom(pending, a, aState)
recomputeDependents(pending, a, aState)
markRecomputePending(pending, a, aState)
}
return undefined as R
} else {
Expand Down Expand Up @@ -732,8 +783,7 @@ const buildStore = (
setAtomStateValueOrPromise(atom, atomState, value)
mountDependencies(pending, atom, atomState)
if (prevEpochNumber !== atomState.n) {
addPendingAtom(pending, atom, atomState)
recomputeDependents(pending, atom, atomState)
markRecomputePending(pending, atom, atomState)
}
}
}
Expand Down
Loading

0 comments on commit e0e0931

Please sign in to comment.