Commit 9fd91f6
feat: fixed dataset generation
kirangadhave committed Oct 21, 2023
1 parent 952970e commit 9fd91f6
Showing 8 changed files with 327 additions and 78 deletions.
252 changes: 205 additions & 47 deletions examples/test_ext_widget.ipynb

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions persist_ext/internals/data/generated.py
@@ -13,8 +13,8 @@ def get(self, name: str):

return self.__dataframe_map[name]

def set(self, name: str, data: DataFrame):
if name in self.__dataframe_map:
def set(self, name: str, data: DataFrame, override):
if not override and name in self.__dataframe_map:
raise KeyError(f"Already exists dataframe named '{name}'")

self.__dataframe_map[name] = data
@@ -31,6 +31,10 @@ def remove(self, name: str):
global_generated_record = GeneratedRecord()


def keys():
return global_generated_record.__dataframe_map.keys()


def remove_dataframe(name: str):
global_generated_record.remove(name)

@@ -39,5 +43,5 @@ def has_dataframe(name: str):
return global_generated_record.has(name)


def add_dataframe(name: str, data: DataFrame):
return global_generated_record.set(name, data)
def add_dataframe(name: str, data: DataFrame, override):
return global_generated_record.set(name, data, override)
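
The new override flag changes the collision behavior of set(): without it, registering a name that already exists raises KeyError; with it, the stored frame is silently replaced. A minimal usage sketch (illustration only, not part of the commit), assuming the helpers are importable from the path shown above:

import pandas as pd
from persist_ext.internals.data.generated import (
    add_dataframe,
    has_dataframe,
    remove_dataframe,
)

df = pd.DataFrame({"a": [1, 2, 3]})

add_dataframe("my_df", df, override=False)    # first registration succeeds
assert has_dataframe("my_df")

add_dataframe("my_df", df, override=True)     # replaces the stored frame
# add_dataframe("my_df", df, override=False)  # would raise KeyError: name taken

remove_dataframe("my_df")
assert not has_dataframe("my_df")

One caveat with the new keys() helper: it reads global_generated_record.__dataframe_map from module scope, where Python's name mangling does not apply, so the attribute it looks up only exists as _GeneratedRecord__dataframe_map and the call would raise AttributeError as written.
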
22 changes: 22 additions & 0 deletions persist_ext/internals/data/process_generate_dataset.py
@@ -0,0 +1,22 @@
from persist_ext.internals.data.idfy import ID_COLUMN
from persist_ext.internals.widgets.vegalite_chart.selection import (
SELECTED_COLUMN_BRUSH,
SELECTED_COLUMN_INTENT,
)


def process_generate_dataset(df, keep_selection_columns=False, keep_id_col=False):
df = df.copy(deep=True)

cols_to_remove = []

if not keep_id_col:
cols_to_remove.append(ID_COLUMN)

if not keep_selection_columns:
cols_to_remove.append(SELECTED_COLUMN_BRUSH)
cols_to_remove.append(SELECTED_COLUMN_INTENT)

df.drop(columns=cols_to_remove, inplace=True)

return df
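
process_generate_dataset works on a deep copy, so the caller's frame is untouched; by default it strips both the internal row-id column and the two selection bookkeeping columns. An illustrative call (not from the commit, values invented) using the same column constants the module imports:

import pandas as pd

from persist_ext.internals.data.idfy import ID_COLUMN
from persist_ext.internals.data.process_generate_dataset import process_generate_dataset
from persist_ext.internals.widgets.vegalite_chart.selection import (
    SELECTED_COLUMN_BRUSH,
    SELECTED_COLUMN_INTENT,
)

df = pd.DataFrame({
    "value": [10, 20],
    ID_COLUMN: ["r1", "r2"],                 # internal row ids
    SELECTED_COLUMN_BRUSH: [True, False],    # brush-selection flag
    SELECTED_COLUMN_INTENT: [False, False],  # intent-selection flag
})

clean = process_generate_dataset(df)                       # only "value" remains
with_ids = process_generate_dataset(df, keep_id_col=True)  # keeps ID_COLUMN as well

Since drop() runs with its default errors="raise", the input is expected to still carry every column slated for removal.
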
41 changes: 36 additions & 5 deletions persist_ext/internals/widgets/base/body_widget_base.py
@@ -3,6 +3,8 @@
import traittypes
from abc import ABC, abstractmethod
from copy import deepcopy
from threading import Thread


import traitlets
from persist_ext.internals.data.idfy import ID_COLUMN, idfy_dataframe
@@ -91,46 +93,73 @@ def _on_interaction_change(self, change):
self._interaction_change(interactions)

def _interaction_change(self, interactions):
# Set is_apply flag to true
self.is_applying = True

# copy all vars to update
copied_var_tuple = self._copy_vars()

# Apply interactions
copied_var_tuple = self._apply_interactions(interactions, *copied_var_tuple)

# Replace traitlets with copies by holding sync
with self.hold_sync():
self._update_copies(*copied_var_tuple)
self.is_applying = False

# Loop over interaction and update the copies
def _apply_interactions(self, interactions, *copied_var_tuple):
# Loop over interaction and update the copies

# Set cache hit status to None
last_cache_hit_id = None
threads = []

# loop over interactions
for interaction in interactions:
id = interaction["id"]
# check if interaction is cached
if id in self._cached_apply_record:
# if yes then set the last_cache_hit_id
last_cache_hit_id = id
else:
# last interaction was cached
if last_cache_hit_id is not None:
# Load the cached values
copied_var_tuple = self._from_cache(
*self._cached_apply_record[last_cache_hit_id]
)
# reset last_cache_hit_id
last_cache_hit_id = None

# Get type and apply
_type = interaction["type"]
fn_name = f"_apply_{_type}"
if hasattr(self, fn_name):
fn = getattr(self, fn_name)
copied_var_tuple = fn(interaction, *copied_var_tuple)
self._cached_apply_record[id] = self._to_cache(*copied_var_tuple)

# Update the cache with interaction id in thread
def __update(id, vars_to_copy):
self._cached_apply_record[id] = vars_to_copy
print("Fin", id)

print("Start", id)
thread = Thread(
target=__update, args=(id, self._to_cache(*copied_var_tuple))
)
thread.start()
threads.append(thread)
else:
raise ValueError(f"Method {fn_name} not implemented")

# If last hit is set, retrieve it and return
if last_cache_hit_id is not None:
copied_var_tuple = self._from_cache(
*self._cached_apply_record[last_cache_hit_id]
)

for t in threads:
t.join()

return copied_var_tuple

@abstractmethod
@@ -144,6 +173,10 @@ def _from_cache(self, *args):
def _default_cache_to_from(self, *args):
return deepcopy(args)

@abstractmethod
def _get_data(self, *args):
pass

@abstractmethod
def _copy_vars(self):
pass
@@ -223,8 +256,6 @@ def _append_annotation(val):
selected(data), ANNOTATE_COLUMN_NAME
].apply(_append_annotation)

print(data[ANNOTATE_COLUMN_NAME].unique())

return data

# rename
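
The heart of _apply_interactions is a replay-with-cache loop: walk the interaction list, skip over the longest prefix whose results are already cached, resume from the last cached state, and cache each newly applied step under its interaction id. The commit additionally hands the cache writes to Thread objects that are joined before returning. A self-contained sketch of the pattern, deliberately simplified and not the widget code itself:

def replay(interactions, initial_state, cache, apply_fn):
    """Replay interactions, reusing cached states keyed by interaction id."""
    state = initial_state
    last_hit = None

    for interaction in interactions:
        iid = interaction["id"]
        if iid in cache:
            last_hit = iid                   # still inside the cached prefix
        else:
            if last_hit is not None:
                state = cache[last_hit]      # resume from the last cached state
                last_hit = None
            state = apply_fn(state, interaction)
            cache[iid] = state               # remember this step for next time

    if last_hit is not None:                 # every interaction was already cached
        state = cache[last_hit]
    return state

In the widget, apply_fn corresponds to the dynamically dispatched _apply_<type> methods, cache to _cached_apply_record, and the cached values pass through _to_cache/_from_cache; the new abstract _get_data hook lets callers pull just the dataframe back out of a cached tuple.
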
69 changes: 51 additions & 18 deletions persist_ext/internals/widgets/header/header_widget.py
@@ -1,11 +1,11 @@
from pandas import DataFrame
import traitlets
from traitlets.traitlets import Unicode
from persist_ext.internals.data.generated import (
add_dataframe,
has_dataframe,
remove_dataframe,
)
from persist_ext.internals.data.process_generate_dataset import process_generate_dataset

from persist_ext.internals.widgets.base.widget_with_trrack import WidgetWithTrrack

@@ -15,14 +15,21 @@ class HeaderWidget(WidgetWithTrrack):

cell_id = Unicode("12").tag(sync=True)

def __init__(self):
def __init__(self, body_widget):
super(HeaderWidget, self).__init__(widget_key=self.__widget_key)
self._body_widget = body_widget

@traitlets.observe("interactions")
def _on_interaction_change_df_gen(self, change):
for df_name, df_record in self.generated_dataframe_record.items():
if df_record["dynamic"] is True and not has_dataframe(df_name):
self.create(df_record)
# Update dynamic dataframes when current changes
for df_record in self.generated_dataframe_record.values():
# Filter in dynamic dataframes
if df_record["dynamic"] is True:
# create
try:
self._create(df_record)
except Exception as ex:
raise ex

@traitlets.observe("generated_dataframe_record")
def _on_generated_dataframe_record(self, change):
@@ -32,18 +39,21 @@ def _on_generated_dataframe_record(self, change):
old = change["old"]
new = change["new"]

# dataframes to add
added_df_names = set(new) - set(old)
# dataframes to remove
removed_df_names = set(old) - set(new)

# remove dfs
for removed in removed_df_names:
remove_dataframe(removed)

# Try to generate dataframe
for df_name, df_record in new.items():
try:
if not has_dataframe(df_name):
df = self._create(df_record)
add_dataframe(df_name, df)
self._create(df_record)

# Notify creation success
if df_name in added_df_names:
msg.append({"type": "df-created", "name": df_name})
except Exception as e:
@@ -52,19 +62,42 @@ def _on_generated_dataframe_record(self, change):
self.send({"msg": msg, "error": err})

def _create(self, record):
# check if dynamic df
is_dynamic = record["dynamic"]

if is_dynamic:
return self.data.copy(deep=True)

data = self._persistent_data.copy(deep=True)

interactions = record["interactions"]
# Set initial data as none
data = None

for interaction in interactions:
print(interaction)

return data
if is_dynamic:
# if dynamic then associate current data
data = self._body_widget.data
else:
if has_dataframe(record["df_name"]):
return
# Generate data for list of interactions present in the record

interactions = record["interactions"]

# get last interactions
last_interaction_id = interactions[-1]["id"]

# if last interaction has no dataframe cached, generate and cache it.
if last_interaction_id not in self._body_widget._cached_apply_record:
self._body_widget._apply_interactions(
interactions, self.data.copy(deep=True)
)

# assign data from cache
data = self._body_widget._get_data(
*self._body_widget._from_cache(
*self._body_widget._cached_apply_record[last_interaction_id]
)
)

# Assign data to df_name
add_dataframe(
record["df_name"], process_generate_dataset(data), override=is_dynamic
)

@traitlets.observe("trrack")
def _on_trrack(self, _change):
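
_create now covers two kinds of records: dynamic ones, whose data is re-registered with override=True on every interaction change so the generated dataframe keeps tracking the body widget's current state, and static ones, which are generated once by replaying the record's interactions through the body widget's cache. Judging from the keys the method reads, a record presumably looks roughly like the following (field values here are invented for illustration):

example_record = {
    "df_name": "my_generated_df",   # name handed to add_dataframe
    "dynamic": False,               # True: mirror current data, overriding on each change
    "interactions": [               # provenance steps replayed for static frames
        {"id": "node-1", "type": "create"},
        # further entries dispatch to _apply_<type> on the body widget
    ],
}
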
@@ -63,7 +63,7 @@ class OutputWithTrrackWidget(VBox):
def __init__(self, body_widget, data):
self._update_body(body_widget)

self.header_widget = HeaderWidget()
self.header_widget = HeaderWidget(body_widget=body_widget)
self.trrack_widget = TrrackWidget()
self.intent_widget = IntentWidget()

@@ -142,14 +142,17 @@ def _update_copies(self, data, chart):
def _to_cache(self, data, chart):
data = data.to_parquet(compression="brotli")
chart = copy_altair_chart(chart)
chart.data = Undefined
return data, chart

def _from_cache(self, data, chart):
data = pd.read_parquet(BytesIO(data))
chart = copy_altair_chart(chart)
return data, chart

def _get_data(self, data, *args, **kwargs):
return data

# Interactions
def _apply_create(self, _interaction, data, chart):
return data, chart

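
_to_cache stores the dataframe side of the cached tuple as brotli-compressed parquet bytes and blanks out the chart copy's data (presumably so the cache does not hold the dataset twice), while _from_cache reads the bytes back into a frame. A standalone illustration of that round trip (requires a parquet engine such as pyarrow; not library code):

from io import BytesIO

import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3]})

blob = df.to_parquet(compression="brotli")   # bytes, as kept in the cache
restored = pd.read_parquet(BytesIO(blob))    # frame reconstructed on the way out

assert restored.equals(df)
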
2 changes: 0 additions & 2 deletions src/widgets/vegalite/Vegalite.tsx
@@ -22,8 +22,6 @@ function Vegalite({ cell }: Props) {
const [wait] = useModelState<number>('debounce_wait');
const [isApplying] = useModelState<boolean>('is_applying');

console.log({ isApplying });

const vegaView = useMemo(() => {
return new VegaView();
}, []); // Initialize VegaView wrapper
