Skip to content

Cache invalidation for XProf #1438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ http_archive(
"//third_party:xla.patch",
"//third_party:xla_add_grpc_cares_darwin_arm64_support.patch",
],
sha256 = "26e457507da92af216814ecf334e0bf8dc9bf49eb6b812ca58c482696390098a",
strip_prefix = "xla-137cad68a7f2c5f4b52ad18acf4acf6056dde1f8",
sha256 = "93f025c617919f0adce7cd7ecd241696a0c85b6c7541b8a7057769cedb2a4c7a",
strip_prefix = "xla-63fe4109a88019662f0b173cbcd8a4c075aceb47",
urls = [
"https://github.com/openxla/xla/archive/137cad68a7f2c5f4b52ad18acf4acf6056dde1f8.zip",
"https://github.com/openxla/xla/archive/63fe4109a88019662f0b173cbcd8a4c075aceb47.zip",
],
)

Expand Down
30 changes: 18 additions & 12 deletions plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,29 +116,32 @@ def xspace_to_tool_data(
content_type = 'application/json'
# tqx: gViz output format
tqx = params.get('tqx', '')
options = {}
options['use_saved_result'] = params.get('use_saved_result', True)
if tool == 'trace_viewer':
# Trace viewer handles one host at a time.
assert len(xspace_paths) == 1
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = process_raw_trace(raw_data)
elif tool == 'trace_viewer@':
# Streaming trace viewer handles one host at a time.
assert len(xspace_paths) == 1
options = params.get('trace_viewer_options', {})
options['use_saved_result'] = params.get('use_saved_result', True)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'overview_page':
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = overview_page_proto_to_gviz.to_json(raw_data)
elif tool == 'input_pipeline_analyzer':
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = input_pipeline_proto_to_gviz.to_json(raw_data)
elif tool == 'framework_op_stats':
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
if tqx == 'out:csv':
data = tf_stats_proto_to_gviz.to_csv(raw_data)
Expand All @@ -147,14 +150,16 @@ def xspace_to_tool_data(
# Try legacy tool name: Handle backward compatibility with lower TF version
else:
legacy_tool = 'tensorflow_stats'
raw_data, success = xspace_wrapper_func(xspace_paths, legacy_tool)
raw_data, success = xspace_wrapper_func(
xspace_paths, legacy_tool, options
)
if success:
if tqx == 'out:csv':
data = tf_stats_proto_to_gviz.to_csv(raw_data)
else:
data = tf_stats_proto_to_gviz.to_json(raw_data)
elif tool == 'kernel_stats':
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
if tqx == 'out:csv;':
data = kernel_stats_proto_to_gviz.to_csv(raw_data)
Expand All @@ -163,29 +168,30 @@ def xspace_to_tool_data(
elif tool == 'memory_profile':
# Memory profile handles one host at a time.
assert len(xspace_paths) == 1
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'pod_viewer':
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'op_profile':
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
elif tool == 'hlo_stats':
json_data, success = xspace_wrapper_func(xspace_paths, tool)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = json_data
elif tool == 'roofline_model':
json_data, success = xspace_wrapper_func(xspace_paths, tool)
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = json_data
elif tool == 'graph_viewer':
download_hlo_types = ['pb', 'pbtxt', 'json', 'short_txt', 'long_txt']
graph_html_type = 'graph'
options = params.get('graph_viewer_options', {})
options['use_saved_result'] = params.get('use_saved_result', True)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = raw_data
Expand Down Expand Up @@ -221,7 +227,7 @@ def xspace_to_tool_data(
if success:
data = dcn_collective_stats_proto_to_gviz.to_json(raw_data)
elif tool == 'inference_profile':
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
if success:
data = inference_stats_proto_to_gviz.to_json(raw_data)
else:
Expand Down
36 changes: 34 additions & 2 deletions plugin/tensorboard_plugin_profile/profile_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import six
from werkzeug import wrappers

from tensorboard_plugin_profile import version
from tensorboard_plugin_profile.convert import raw_to_tool_data as convert
from tensorboard_plugin_profile.standalone.tensorboard_shim import base_plugin
from tensorboard_plugin_profile.standalone.tensorboard_shim import context as tb_context
Expand Down Expand Up @@ -73,6 +74,7 @@
HLO_MODULE_LIST_ROUTE = '/module_list'
CAPTURE_ROUTE = '/capture_profile'
LOCAL_ROUTE = '/local'
CACHE_VERSION_FILE = 'cache_version.txt'

# Suffixes of "^, #, @" symbols represent different input data formats for the
# same tool.
Expand Down Expand Up @@ -619,15 +621,34 @@ def data_impl(
host = request.args.get('host')
module_name = request.args.get('module_name')
tqx = request.args.get('tqx')
use_saved_result_str = request.args.get('use_saved_result', 'true')
use_saved_result = use_saved_result_str.lower() != 'false'
run_dir = self._run_dir(run)

# Check if the cache file exists and if the version is the same as the
# current version. If not, set use_saved_result to False.
try:
if epath.Path(os.path.join(run_dir, CACHE_VERSION_FILE)).exists():
with epath.Path(os.path.join(run_dir, CACHE_VERSION_FILE)).open(
'r'
) as f:
cache_version = f.read().strip()
if cache_version != version.__version__:
use_saved_result = False
else:
use_saved_result = False
except OSError as e:
logger.warning('Cannot read cache version file: %s', e)

graph_viewer_options = self._get_graph_viewer_options(request)
# Host param is used by HLO tools to identify the module.
params = {
'graph_viewer_options': graph_viewer_options,
'tqx': tqx,
'host': host,
'module_name': module_name
'module_name': module_name,
'use_saved_result': use_saved_result,
}
run_dir = self._run_dir(run)
content_type = 'application/json'

if tool not in TOOLS and not use_xplane(tool):
Expand Down Expand Up @@ -681,6 +702,17 @@ def data_impl(
except FileNotFoundError as e:
logger.warning('XPlane convert to tool data failed as %s', e)
raise e

# Write cache version file if use_saved_result is False.
if not use_saved_result:
try:
with epath.Path(os.path.join(run_dir, CACHE_VERSION_FILE)).open(
'w'
) as f:
f.write(version.__version__)
except OSError as e:
logger.warning('Cannot write cache version file: %s', e)

return data, content_type, content_encoding

logger.info('%s does not use xplane', tool)
Expand Down
74 changes: 72 additions & 2 deletions plugin/tensorboard_plugin_profile/profile_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from tensorboard_plugin_profile import profile_plugin
from tensorboard_plugin_profile import profile_plugin_test_utils as utils
from tensorboard_plugin_profile import version
from tensorboard_plugin_profile.protobuf import trace_events_pb2
from tensorboard_plugin_profile.standalone.tensorboard_shim import plugin_asset_util
from tensorboard_plugin_profile.standalone.tensorboard_shim import plugin_event_multiplexer
Expand Down Expand Up @@ -115,8 +116,12 @@ def setUp(self):
def get_temp_dir(self):
"""Return a temporary directory for tests to use."""
if not self._temp_dir:
if os.environ.get('TEST_TMPDIR'):
temp_dir = tempfile.mkdtemp(prefix=os.environ['TEST_TMPDIR'])
# If the test is running on Forge, use the TEST_UNDECLARED_OUTPUTS_DIR
# environment variable to store the temporary directory for Sponge.
if os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'):
temp_dir = tempfile.mkdtemp(
dir=os.environ['TEST_UNDECLARED_OUTPUTS_DIR']
)
else:
frame = inspect.stack()[-1]
filename = frame.filename
Expand Down Expand Up @@ -287,6 +292,71 @@ def testData(self):
self.plugin.data_impl(utils.make_data_request(
run='a', tool='trace_viewer', host=''))

def testDataWithCache(self):
  """End-to-end check of the cache_version.txt based cache invalidation.

  Verifies that data_impl:
    * writes cache_version.txt on a fresh (uncached) run,
    * leaves it untouched when the cache is reused,
    * rewrites it when use_saved_result=False forces regeneration,
    * rewrites it when the stored version is older than the plugin version.
  """
  generate_testdata(self.logdir)
  self.multiplexer.AddRunsFromDirectory(self.logdir)
  self.multiplexer.Reload()
  run_dir = os.path.join(
      plugin_asset_util.PluginDirectory(
          self.logdir, profile_plugin.ProfilePlugin.plugin_name
      ),
      'abc',
  )
  cache_version_file_path = os.path.join(
      run_dir, profile_plugin.CACHE_VERSION_FILE
  )

  # The cache_version.txt file must not exist before any request is served.
  self.assertFalse(os.path.exists(cache_version_file_path))

  # The first request has no cache, so it generates the cache version file
  # stamped with the current plugin version.
  _, _, _ = self.plugin.data_impl(
      utils.make_data_request(run='abc', tool='overview_page', host='host1')
  )
  self.assertTrue(os.path.exists(cache_version_file_path))
  with open(cache_version_file_path, 'r') as f:
    self.assertEqual(f.read(), version.__version__)
  cache_file_first_run_timestamp = os.path.getmtime(cache_version_file_path)

  # A second request hits the cache, so the cache version file must NOT be
  # rewritten (its mtime stays the same).
  _, _, _ = self.plugin.data_impl(
      utils.make_data_request(run='abc', tool='overview_page', host='host1')
  )
  self.assertTrue(os.path.exists(cache_version_file_path))
  with open(cache_version_file_path, 'r') as f:
    self.assertEqual(f.read(), version.__version__)
  self.assertEqual(
      cache_file_first_run_timestamp,
      os.path.getmtime(cache_version_file_path),
  )

  # use_saved_result=False forces regeneration, which rewrites the cache
  # version file (its mtime advances past the first run's).
  _, _, _ = self.plugin.data_impl(
      utils.make_data_request(
          run='abc',
          tool='overview_page',
          host='host1',
          use_saved_result='False',
      )
  )
  self.assertTrue(os.path.exists(cache_version_file_path))
  with open(cache_version_file_path, 'r') as f:
    self.assertEqual(f.read(), version.__version__)
  self.assertLess(
      cache_file_first_run_timestamp,
      os.path.getmtime(cache_version_file_path),
  )

  # A stale version in cache_version.txt invalidates the cache: the next
  # request must regenerate and stamp the file with the current version.
  with open(cache_version_file_path, 'w') as f:
    f.write('1.0.0')
  _, _, _ = self.plugin.data_impl(
      utils.make_data_request(run='abc', tool='overview_page', host='host1')
  )
  self.assertTrue(os.path.exists(cache_version_file_path))
  with open(cache_version_file_path, 'r') as f:
    self.assertEqual(f.read(), version.__version__)

def testActive(self):

def wait_for_thread():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ def create_profile_plugin(logdir,
return profile_plugin.ProfilePlugin(context)


def make_data_request(run, tool, host=None):
def make_data_request(run, tool, host=None, use_saved_result=None):
"""Creates a werkzeug.Request to pass as argument to ProfilePlugin.data_impl.

Args:
run: Front-end run name.
tool: ProfilePlugin tool, e.g., 'trace_viewer'.
host: Host that generated the profile data, e.g., 'localhost'.
use_saved_result: Whether to use cache.

Returns:
A werkzeug.Request to pass to ProfilePlugin.data_impl.
Expand All @@ -73,4 +74,6 @@ def make_data_request(run, tool, host=None):
req.args = {'run': run, 'tag': tool}
if host:
req.args['host'] = host
if use_saved_result is not None:
req.args['use_saved_result'] = use_saved_result
return req
2 changes: 2 additions & 0 deletions xprof/convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,8 @@ cc_test(
":repository",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
"@org_xprof//plugin/tensorboard_plugin_profile/protobuf:op_stats_proto_cc",
"@tsl//tsl/platform:path",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
"@xla//xla/tsl/platform:errors",
"@xla//xla/tsl/platform:status",
Expand Down
21 changes: 21 additions & 0 deletions xprof/convert/repository.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,26 @@ absl::StatusOr<std::pair<bool, std::string>> SessionSnapshot::HasCacheFile(
return std::pair<bool, std::string>(false, std::string());
}

absl::Status SessionSnapshot::ClearCacheFiles() const {
  // A snapshot without an accessible run directory holds no on-disk cache,
  // so there is nothing to delete.
  if (!has_accessible_run_dir_) return absl::OkStatus();

  // A file is a cache file iff its name ends with one of the registered
  // per-host data suffixes.
  const auto is_cache_file = [](const std::string& name) {
    for (const auto& type_and_suffix : *kHostDataSuffixes) {
      if (absl::EndsWith(name, type_and_suffix.second)) return true;
    }
    return false;
  };

  // Enumerate the session run directory and remove every cache file found.
  std::vector<std::string> children;
  TF_RETURN_IF_ERROR(::tsl::Env::Default()->GetChildren(
      std::string(GetSessionRunDir()), &children));
  for (const std::string& child : children) {
    if (is_cache_file(child)) {
      TF_RETURN_IF_ERROR(tsl::Env::Default()->DeleteFile(
          tsl::io::JoinPath(GetSessionRunDir(), child)));
    }
  }
  return absl::OkStatus();
}

} // namespace profiler
} // namespace tensorflow
2 changes: 2 additions & 0 deletions xprof/convert/repository.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class SessionSnapshot {
absl::StatusOr<std::pair<bool, std::string>> HasCacheFile(
StoredDataType data_type) const;

absl::Status ClearCacheFiles() const;

template <typename T>
absl::Status WriteBinaryProto(const StoredDataType data_type,
const std::string host, T& proto) const {
Expand Down
36 changes: 36 additions & 0 deletions xprof/convert/repository_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ limitations under the License.
#include "absl/status/status.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/status.h"
#include "tsl/platform/path.h"
#include "tsl/profiler/protobuf/xplane.pb.h"
#include "plugin/tensorboard_plugin_profile/protobuf/op_stats.pb.h"

namespace tensorflow {
namespace profiler {
Expand Down Expand Up @@ -137,6 +139,40 @@ TEST(Repository, MismatchedXSpaceAndPath) {
EXPECT_THAT(session_snapshot_or.status().message(), error);
}

// Verifies that SessionSnapshot::ClearCacheFiles() deletes previously written
// cache files (here: a cached OP_STATS proto) from the session run directory.
TEST(Repository, ClearCacheFiles) {
  // Create a temp directory tree that mimics a profiler log layout:
  // <tmp>/log/plugins/profile containing one xplane file for "hostname0".
  std::string temp_dir = ::testing::TempDir();
  std::string profile_dir = tsl::io::JoinPath(temp_dir, "log/plugins/profile");
  TF_CHECK_OK(tsl::Env::Default()->RecursivelyCreateDir(profile_dir));
  std::string xplane_path =
      tsl::io::JoinPath(profile_dir, "hostname0.xplane.pb");

  std::vector<std::unique_ptr<XSpace>> xspaces;
  // Prepare an XSpace for host 0.
  // NOTE(review): `xspaces` is populated but Create() is called with
  // /*xspaces=*/std::nullopt below, so the snapshot is path-based — confirm
  // whether the vector is still needed.
  auto space0 = std::make_unique<XSpace>();
  *(space0->add_hostnames()) = "hostname0";
  xspaces.push_back(std::move(space0));
  auto session_snapshot_or =
      SessionSnapshot::Create({xplane_path}, /*xspaces=*/std::nullopt);
  TF_CHECK_OK(session_snapshot_or.status());

  // Write a dummy OP_STATS proto so a cache file exists on disk, and confirm
  // the snapshot can locate it.
  OpStats op_stats;
  op_stats.set_allocated_run_environment(new RunEnvironment());
  TF_CHECK_OK(session_snapshot_or.value().WriteBinaryProto(
      StoredDataType::OP_STATS, "hostname0", op_stats));
  auto opt_statsfile_path = session_snapshot_or.value().GetHostDataFilePath(
      StoredDataType::OP_STATS, "hostname0");
  EXPECT_TRUE(opt_statsfile_path.value().has_value());

  // After ClearCacheFiles() the cached OP_STATS file must be gone.
  TF_CHECK_OK(session_snapshot_or.value().ClearCacheFiles());
  opt_statsfile_path = session_snapshot_or.value().GetHostDataFilePath(
      StoredDataType::OP_STATS, "hostname0");
  EXPECT_FALSE(opt_statsfile_path.value().has_value());
}

} // namespace
} // namespace profiler
} // namespace tensorflow
Loading