[OSPP 2022] DeepRec supports exporting models to key-value NoSQL databases -> The first feature submission #470
Status: Open. MrRobotsAA wants to merge 3 commits into DeepRec-AI:main from MrRobotsAA:main.

The diff below shows changes from 1 commit of the 3.
@@ -17,6 +17,7 @@ limitations under the License.
 #include <string>
 #include <vector>
+#include <iostream>

 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -31,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/tensor_bundle/db_writer.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
 #include "tensorflow/core/util/tensor_slice_reader.h"
@@ -106,25 +108,19 @@ class SaveV2 : public OpKernel {
   }

   template <typename TKey, typename TValue>
-  void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index,
-                            const string& tensor_name, BundleWriter& writer,
-                            DataType global_step_type) {
+  void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index, const string& tensor_name, BundleWriter& writer, DataType global_step_type) {
     if (global_step_type == DT_INT32) {
-      DumpEv<TKey, TValue, int32>(context, variable_index,
-                                  tensor_name, writer);
+      DumpEv<TKey, TValue, int32>(context, variable_index, tensor_name, writer);
     } else {
-      DumpEv<TKey, TValue, int64>(context, variable_index,
-                                  tensor_name, writer);
+      DumpEv<TKey, TValue, int64>(context, variable_index, tensor_name, writer);
     }
   }

   template <typename TKey, typename TValue, typename TGlobalStep>
-  void DumpEv(OpKernelContext* context, int variable_index,
-              const string& tensor_name, BundleWriter& writer) {
+  void DumpEv(OpKernelContext* context, int variable_index, const string& tensor_name, BundleWriter& writer) {
     EmbeddingVar<TKey, TValue>* variable = nullptr;
     OP_REQUIRES_OK(context,
-                   LookupResource(context,
-                                  HandleFromInput(context, variable_index), &variable));
+                   LookupResource(context, HandleFromInput(context, variable_index), &variable));
     const Tensor& global_step = context->input(3);
     Tensor part_offset_tensor;
     context->allocate_temp(DT_INT32,
@@ -136,8 +132,7 @@ class SaveV2 : public OpKernel {
       OP_REQUIRES_OK(context, variable->Shrink());
     else
       OP_REQUIRES_OK(context, variable->Shrink(global_step_scalar));
-    OP_REQUIRES_OK(context, DumpEmbeddingValues(variable, tensor_name,
-                                                &writer, &part_offset_tensor));
+    OP_REQUIRES_OK(context, DumpEmbeddingValues(variable, tensor_name, &writer, &part_offset_tensor));
   }

   void Compute(OpKernelContext* context) override {
@@ -146,38 +141,57 @@ class SaveV2 : public OpKernel {
     const Tensor& shape_and_slices = context->input(2);
     ValidateInputs(true /* is save op */, context, prefix, tensor_names,
                    shape_and_slices);
     if (!context->status().ok()) return;

     const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
-    const int num_tensors = static_cast<int>(tensor_names.NumElements());
+    const int num_tensors = static_cast<int>(tensor_names.NumElements());  // get the number of tensors

     const string& prefix_string = prefix.scalar<tstring>()();
     const auto& tensor_names_flat = tensor_names.flat<tstring>();
     const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

-    BundleWriter writer(Env::Default(), prefix_string);
+    const int Nosql_Marker = 0;
+
+    auto tempstate = random::New64();
+    string db_prefix_tmp = strings::StrCat(prefix_string,"--temp",tempstate);
+    DBWriter dbwriter(Env::Default(), prefix_string,db_prefix_tmp);
+    OP_REQUIRES_OK(context, dbwriter.status());
+
+    BundleWriter writer(Env::Default(), prefix_string,db_prefix_tmp);
     OP_REQUIRES_OK(context, writer.status());
     VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;
+
+
     int start_index = 0;
     if (has_ev_) {
       start_index = 1;
     }
+
+
     int start_ev_key_index = 0;
+
+
     for (int i = start_index; i < num_tensors; ++i) {
-      const string& tensor_name = tensor_names_flat(i);
-      if (tensor_types_[i] == DT_RESOURCE) {
+      const string& tensor_name = tensor_names_flat(i);
+
+
+      if (tensor_types_[i] == DT_RESOURCE)
+      {
         auto& handle = HandleFromInput(context, i + kFixedInputs);
         if (IsHandle<EmbeddingVar<int64, float>>(handle)) {
           EmbeddingVar<int64, float>* variable = nullptr;
           OP_REQUIRES_OK(context,
                          LookupResource(context, HandleFromInput(context, i + kFixedInputs), &variable));
           core::ScopedUnref unref_variable(variable);
           const Tensor& global_step = context->input(3);
           Tensor part_offset_tensor;
           context->allocate_temp(DT_INT32,
                                  TensorShape({kSavedPartitionNum + 1}),
                                  &part_offset_tensor);

           if (ev_key_types_[start_ev_key_index] == DT_INT32) {
-            DumpEvWithGlobalStep<int32, float>(context,
-                i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
+            DumpEvWithGlobalStep<int32, float>(context, i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
           } else if (ev_key_types_[start_ev_key_index] == DT_INT64) {
-            DumpEvWithGlobalStep<int64, float>(context,
-                i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
+            DumpEvWithGlobalStep<int64, float>(context, i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
           }
-        } else if (IsHandle<HashTableResource>(handle)) {
+        }
+        else if (IsHandle<HashTableResource>(handle)) {
           auto handles = context->input(i + kFixedInputs).flat<ResourceHandle>();
           int tensible_size = handles.size() - 1;
           std::vector<core::ScopedUnref> unrefs;
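Note: judging purely from the call sites in this hunk (and in the save path further below), the new DBWriter brought in via tensorflow/core/util/tensor_bundle/db_writer.h appears to expose an interface parallel to BundleWriter. The sketch below is inferred from usage in this diff only; it is not the actual header from the PR, and the real class may differ.

// Sketch of the DBWriter surface implied by the call sites in this diff.
// Inferred from usage only; the real db_writer.h in this PR may differ.
class DBWriter {
 public:
  // Mirrors BundleWriter's construction, plus the shared temporary
  // "--temp<random>" prefix built with strings::StrCat above.
  DBWriter(Env* env, const string& prefix, const string& db_prefix_tmp);

  // Construction status, checked with OP_REQUIRES_OK right after creation.
  Status status() const;

  // Whole-tensor write; the trailing tag ("normal_tensor" at the call site)
  // labels the record kind in the key-value store.
  Status Add(const string& tensor_name, const Tensor& tensor,
             const string& tag);

  // Partitioned-tensor write; tagged "slice_tensor" at the call site.
  Status AddSlice(const string& tensor_name, const TensorShape& shape,
                  const TensorSlice& slice, const Tensor& tensor,
                  const string& tag);

  // Flushes buffered records to the database, analogous to
  // BundleWriter::Finish().
  Status Finish();
};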
@@ -205,7 +219,6 @@ class SaveV2 : public OpKernel {
           OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                       shape_spec, &shape, &slice, &slice_shape));
-
           std::vector<string> names_lst = str_util::Split(tensor_name, '|');
           for (auto&& name : names_lst) {
             std::vector<string> tensor_name_x =
@@ -218,15 +231,16 @@ class SaveV2 : public OpKernel {
             OP_REQUIRES_OK(context, SaveHashTable(
                 &writer, hashtable, tensibles, table_name, tensible_name,
                 slice.start(0), slice.length(0), slice_shape.dim_size(0)));
+
+
           }
-        } else if (IsHandle<HashTableAdmitStrategyResource>(handle)) {
+        }
+        else if (IsHandle<HashTableAdmitStrategyResource>(handle)) {
           HashTableAdmitStrategyResource* resource;
           OP_REQUIRES_OK(context,
-                         LookupResource(context,
-                                        HandleFromInput(context, i + kFixedInputs), &resource));
+                         LookupResource(context, HandleFromInput(context, i + kFixedInputs), &resource));
           HashTableAdmitStrategy* strategy = resource->Internal();
-          BloomFilterAdmitStrategy* bf =
-              dynamic_cast<BloomFilterAdmitStrategy*>(strategy);
+          BloomFilterAdmitStrategy* bf = dynamic_cast<BloomFilterAdmitStrategy*>(strategy);
           CHECK(bf != nullptr) << "Cannot save Non-BloomFilterAdmitStrategy!";

           string shape_spec = shape_and_slices_flat(i);
@@ -240,33 +254,57 @@ class SaveV2 : public OpKernel {
               &writer, bf, tensor_name, slice.start(0),
               slice.length(0), slice_shape.dim_size(0)));
         }
+
+
         start_ev_key_index++;
-      } else {
+      }
+      else
+      {
         const Tensor& tensor = context->input(i + kFixedInputs);

         if (!shape_and_slices_flat(i).empty()) {
           const string& shape_spec = shape_and_slices_flat(i);
           TensorShape shape;
           TensorSlice slice(tensor.dims());
           TensorShape slice_shape;
+
+
           OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
-                                      shape_spec, &shape, &slice, &slice_shape));
+                                    shape_spec, &shape, &slice, &slice_shape));
           OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
-                      errors::InvalidArgument("Slice in shape_and_slice "
-                                              "specification does not match the "
-                                              "shape of the tensor to save: ",
-                                              shape_spec, ", tensor: ",
-                                              tensor.shape().DebugString()));
+                    errors::InvalidArgument("Slice in shape_and_slice "
+                                            "specification does not match the "
+                                            "shape of the tensor to save: ",
+                                            shape_spec, ", tensor: ",
+                                            tensor.shape().DebugString()));

-          OP_REQUIRES_OK(context,
-                         writer.AddSlice(tensor_name, shape, slice, tensor));
-        } else {
-          OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
+          if(Nosql_Marker==1){
+            OP_REQUIRES_OK(context,
+                           dbwriter.AddSlice(tensor_name, shape, slice, tensor,"slice_tensor"));
+          }
+          else{
+            OP_REQUIRES_OK(context,
+                           writer.AddSlice(tensor_name, shape, slice, tensor));
+          }
+        }
+        else {
+          if(Nosql_Marker==1){
+            OP_REQUIRES_OK(context, dbwriter.Add(tensor_name, tensor,"normal_tensor"));
+          }
+          else{
+            string tmp_dbfile_prefix_string = strings::StrCat(prefix_string,"--temp",tempstate,"--data--0--1","--tensor--",tensor_name);
+            OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor, tmp_dbfile_prefix_string));
+          }
         }
       }
     }
-    OP_REQUIRES_OK(context, writer.Finish());
+    if(Nosql_Marker==1){
+      OP_REQUIRES_OK(context, dbwriter.Finish());
+    }
+    else{
+      OP_REQUIRES_OK(context, writer.Finish());
+    }
   }
  private:
   DataTypeVector tensor_types_;
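Note: Nosql_Marker is hard-coded to 0 in this hunk, so the DBWriter path is compiled in but never taken; every tensor still flows through BundleWriter, and the matching Finish() is chosen by the same flag. Below is a self-contained sketch of this dual-sink dispatch pattern, with hypothetical stand-in types (FileBundleSink, KvStoreSink) replacing the TensorFlow classes; it illustrates the control flow only, not the PR's actual API.

#include <iostream>
#include <string>

// Hypothetical stand-ins for BundleWriter / DBWriter from the diff above.
struct FileBundleSink {
  void Add(const std::string& name) {
    std::cout << "bundle file <- " << name << "\n";
  }
  void Finish() { std::cout << "bundle finished\n"; }
};

struct KvStoreSink {
  void Add(const std::string& name, const std::string& tag) {
    std::cout << "kv store <- " << name << " [" << tag << "]\n";
  }
  void Finish() { std::cout << "kv store finished\n"; }
};

// Mirrors the Nosql_Marker dispatch in SaveV2::Compute: one flag chooses the
// backend per tensor, and the matching Finish() runs once at the end.
void SaveAll(bool use_nosql, const std::string names[], int n) {
  FileBundleSink bundle;
  KvStoreSink kv;
  for (int i = 0; i < n; ++i) {
    if (use_nosql) {
      kv.Add(names[i], "normal_tensor");
    } else {
      bundle.Add(names[i]);
    }
  }
  if (use_nosql) kv.Finish(); else bundle.Finish();
}

int main() {
  const std::string names[] = {"embedding/part_0", "dense/kernel"};
  SaveAll(false, names, 2);  // marker == 0: the PR's current default path
  SaveAll(true, names, 2);   // marker == 1: the NoSQL export path
}

As in the diff, both writers are constructed up front and the flag selects per call; a leaner variant would construct only the selected backend, at the cost of restructuring the early OP_REQUIRES_OK status checks.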
@@ -278,8 +316,7 @@ REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);
 // Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
 class RestoreHashTableOp : public AsyncOpKernel {
  public:
-  explicit RestoreHashTableOp(OpKernelConstruction* context)
-      : AsyncOpKernel(context) {
+  explicit RestoreHashTableOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("clear", &clear_));
   }
@@ -289,8 +326,7 @@ class RestoreHashTableOp : public AsyncOpKernel {
     const Tensor& shape_and_slices = context->input(2);
     const Tensor& handles = context->input(3);
     const string& prefix_string = prefix.scalar<string>()();
-    const string& shape_and_slices_string =
-        shape_and_slices.scalar<string>()();
+    const string& shape_and_slices_string = shape_and_slices.scalar<string>()();
     auto tensor_names_flat = tensor_names.flat<string>();
     auto handles_flat = handles.flat<ResourceHandle>();
@@ -376,8 +412,7 @@ class RestoreHashTableOp : public AsyncOpKernel {
  private:
   bool clear_;
 };
-REGISTER_KERNEL_BUILDER(Name("RestoreHashTable").Device(DEVICE_CPU),
-                        RestoreHashTableOp);
+REGISTER_KERNEL_BUILDER(Name("RestoreHashTable").Device(DEVICE_CPU), RestoreHashTableOp);

 class RestoreBloomFilterOp : public AsyncOpKernel {
  public:
@@ -408,8 +443,7 @@ class RestoreBloomFilterOp : public AsyncOpKernel {
       OP_REQUIRES_OK_ASYNC(
           context, LookupResource(context, handle_flat, &resource), done);
       strategy = dynamic_cast<BloomFilterAdmitStrategy*>(resource->Internal());
-      CHECK(strategy != nullptr)
-          << "Cannot restore BloomFilter from another strategy";
+      CHECK(strategy != nullptr) << "Cannot restore BloomFilter from another strategy";
     }
     Status st = RestoreBloomFilter(
         reader.get(), strategy, tensor_name_flat, slice.start(0),
@@ -418,8 +452,7 @@ class RestoreBloomFilterOp : public AsyncOpKernel {
     done();
   }
 };
-REGISTER_KERNEL_BUILDER(Name("RestoreBloomFilter").Device(DEVICE_CPU),
-                        RestoreBloomFilterOp);
+REGISTER_KERNEL_BUILDER(Name("RestoreBloomFilter").Device(DEVICE_CPU), RestoreBloomFilterOp);

 // Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
 class RestoreV2 : public OpKernel {
@@ -438,7 +471,6 @@ class RestoreV2 : public OpKernel {
                                     " expected dtypes."));
     ValidateInputs(false /* not save op */, context, prefix, tensor_names,
                    shape_and_slices);
-
     if (!context->status().ok()) return;

     const string& prefix_string = prefix.scalar<tstring>()();
@@ -501,7 +533,9 @@ class MergeV2Checkpoints : public OpKernel {
     const string& merged_prefix = destination_prefix.scalar<tstring>()();
     OP_REQUIRES_OK(
         context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
+    // merge the different checkpoint source files

+    // delete the old directories
     if (delete_old_dirs_) {
       const string merged_dir(io::Dirname(merged_prefix));
       for (const string& input_prefix : input_prefixes) {
Review comment: Keep each line under 80 characters; restore the previous formatting. The same applies below.
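For reference, this appears to be the pre-existing wrapped style the reviewer is asking to restore, taken from the removed lines of the first SaveV2 hunk above:

// Original formatting, kept under 80 columns:
template <typename TKey, typename TValue>
void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index,
                          const string& tensor_name, BundleWriter& writer,
                          DataType global_step_type) {
  // ... body unchanged ...
}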