Skip to content

Commit 605c82e

Browse files
bani-intelaipgsayantan-nervana
authored andcommitted
Bani / Support tf_ellipsis_mask in stridedslice NGTF-2404 (#369)
1 parent b40f6ee commit 605c82e

File tree

6 files changed

+194
-278
lines changed

6 files changed

+194
-278
lines changed

ngraph_bridge/ngraph_builder.cc

Lines changed: 90 additions & 256 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "ngraph/op/argmin.hpp"
2929
#include "ngraph/op/experimental/layers/interpolate.hpp"
3030
#include "ngraph/op/util/logical_reduction.hpp"
31+
#include "ngraph/slice_plan.hpp"
3132

3233
#include "logging/ngraph_log.h"
3334
#include "ngraph_bridge/ngraph_api.h"
@@ -4410,7 +4411,6 @@ static Status TranslateSqueezeOp(const Node* op,
44104411
static Status TranslateStridedSliceOp(
44114412
const Node* op, const std::vector<const Tensor*>& static_input_map,
44124413
Builder::OpMap& ng_op_map) {
4413-
// TODO: implement new_axis_mask, ellipsis_mask
44144414
shared_ptr<ng::Node> ng_input;
44154415
TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, &ng_input));
44164416

@@ -4442,277 +4442,111 @@ static Status TranslateStridedSliceOp(
44424442
TF_RETURN_IF_ERROR(
44434443
GetStaticInputVector(op, 3, static_input_map, &stride_vec));
44444444

4445-
auto& input_shape = ng_input->get_shape();
4446-
4447-
// Summary: Convert tf indexes (-inf, inf) to clamped_begin_idx [0, d] and
4448-
// clamped_end_idx [-1, d], which are then converted to ngraph indexes [0,
4449-
// d]
4450-
// tf->ng is done through tf_to_ng, which calls clamper, which converts
4451-
// tf->clamped
4452-
4453-
// Graph/function for tf->cmapled
4454-
// | ....... <-- y = max_val (max_val = 5)
4455-
// .| .
4456-
// . | .
4457-
// . | . <-- y = x>=0 ? x : x+max_val
4458-
// . |.
4459-
// -.-.-.----.------------ <-- y = 0 (for inclusive)
4460-
// * * | <-- y = -1 (for exclusive)
4461-
// |
4462-
// X axis: TF indexes. Y axis: Clamped indexes
4463-
4464-
// clamper is a function that implements the graph above.
4465-
// For inclusive, the graph is clamped at 0 and dim-1
4466-
// Given dimension d, [0, d-1] are valid locations.
4467-
// -1 represents std::rend(). d represents std::end().
4468-
// These two are useful for representing exclusive boundaries for end-ranges
4469-
// Example for dim = 3:
4470-
// ranges: (-inf,-d)| [-d,0) |[0,d-1]|(d-1,inf)
4471-
// TF index: -5 -4 |-3 -2 -1 | 0 1 2 | 3 4 5
4472-
// clamped begin (inclusive): 0 0 | 0 1 2 | 0 1 2 | 3 3 3
4473-
// clamped end (exclusive): -1 -1 | 0 1 2 | 0 1 2 | 3 3 3
4474-
auto clamper = [](int idx, size_t dim, bool inclusive) {
4475-
// if idx is in [-(d-1), d-1], then its same for both inclusive and
4476-
// exclusive
4477-
// The first 2 cases breaks down this range
4478-
if (idx >= 0 && idx <= (static_cast<int>(dim) - 1)) {
4479-
return idx;
4480-
} else if (idx < 0 &&
4481-
idx + static_cast<int>(dim) >=
4482-
0) { // careful not to do idx >= -dim
4483-
// (since dim is unsigned)
4484-
return idx + static_cast<int>(
4485-
dim); // Type casting to int to enable unambiguous auto
4486-
// type inference of return type
4487-
} else if (idx > static_cast<int>(dim) - 1) {
4488-
return static_cast<int>(dim);
4489-
} else if (idx + static_cast<int>(dim) < 0) {
4490-
// The next case handles the clamping (differently for inclusive and
4491-
// exclusive cases)
4492-
4493-
// careful not to do idx < -dim (since dim is unsigned)
4494-
return 0 - (inclusive ? 0 : 1);
4495-
}
4496-
// Default case
4497-
return 0;
4498-
};
4499-
4500-
auto tf_to_ng = [clamper](int tf_begin_idx, int tf_end_idx, int tf_stride,
4501-
size_t dim, bool begin_mask, bool end_mask,
4502-
bool shrink_mask) {
4503-
// if begin mask is present, depending on stride sign use 0 (std::begin)
4504-
// or
4505-
// dim-1 (std::rbegin)
4506-
// clamped_end_idx could line in [-1, d]
4507-
int tf_ignore_begin_if_needed =
4508-
begin_mask ? (tf_stride > 0 ? 0 : dim - 1) : tf_begin_idx;
4509-
// if end mask is present, depending on stride sign use -1 (std::rend) or
4510-
// dim (std::end).
4511-
// However note, we cannot set to -1, since it has another meaning, hence
4512-
// setting to -(dim+1), which would translate to -1 in clamped coordinates
4513-
// take care to convert dim from sixze_t to int
4514-
int tf_ignore_end_if_needed =
4515-
end_mask ? (tf_stride > 0 ? dim : (-((int)dim + 1))) : tf_end_idx;
4516-
// using size_t for clamped_begin_idx because: clamped_begin_idx is
4517-
// inclusive, so it must lie in [0, dim-1]
4518-
size_t clamped_begin_idx = clamper(tf_ignore_begin_if_needed, dim, true);
4519-
int64 clamped_end_idx =
4520-
clamper(shrink_mask ? clamped_begin_idx + 1 : tf_ignore_end_if_needed,
4521-
dim, false);
4522-
4523-
// Now we have converted semantically non-monotonic and unbounded TF
4524-
// indexes
4525-
// (-inf, inf) to bounded and monotonic clamped indexes [-1, d]
4526-
// Now we need to convert clamped indexes [-1, d] to ngraph indexes [0, d]
4527-
// (taking care of reversal in case of negative strides)
4528-
4529-
size_t needs_reverse = 0;
4530-
size_t ng_begin_idx, ng_end_idx;
4531-
4532-
if (!shrink_mask) {
4533-
if ((int)clamped_begin_idx == clamped_end_idx) {
4534-
// Empty due to matching indexes
4535-
ng_begin_idx = clamped_begin_idx;
4536-
// Type safety: clamped_begin_idx == clamped_end_idx implies,
4537-
// clamped_end_idx!=-1 (since clamped_begin_idx cannot be -1), hence
4538-
// end
4539-
// index assignment is type safe
4540-
ng_end_idx = clamped_end_idx;
4541-
} else { // In the whole of this else: clamped_begin_idx !=
4542-
// clamped_end_idx, so !(a < b) iff a > b and vice versa when
4543-
// comparing the indexes
4544-
// take care to use (int) typecase when comparing int and size_t
4545-
if (((int)clamped_begin_idx < clamped_end_idx) != (tf_stride > 0)) {
4546-
// Empty due to mismatching directions
4547-
ng_begin_idx = clamped_begin_idx;
4548-
// Type safe: since clamped_begin_idx is size_t (>0)
4549-
// [0:-4:1] in TF would convert to [0:-1:1] in clamped domain. hence
4550-
// we do not assign ng_end_idx = clamped_end_idx (which would not be
4551-
// type safe due to the -1)
4552-
ng_end_idx = clamped_begin_idx;
4553-
// Any assignment where ng_begin_idx = ng_end_idx = x (where 0 <= x
4554-
// <=
4555-
// d-1) would have worked for the 2 empty cases above
4556-
}
4557-
// Anything after this is non-empty. Anything before this has dealt
4558-
// with
4559-
// empty cases
4560-
else {
4561-
// in this case either (clamped_begin_idx < clamped_end_idx &&
4562-
// tf_stride > 0) or (clamped_begin_idx > clamped_end_idx &&
4563-
// tf_stride
4564-
// < 0)
4565-
// that is clamped_begin_idx < clamped_end_idx <==> tf_stride > 0.
4566-
// hence using only 1 of the clauses is enough
4567-
if (tf_stride > 0) {
4568-
ng_begin_idx = clamped_begin_idx;
4569-
// Type safety: tf_stride > 0 ==> clamped_begin_idx <
4570-
// clamped_end_idx. clamped_begin_idx could be 0,
4571-
// which means clamped_end_idx > 0. Hence type-safe
4572-
ng_end_idx = clamped_end_idx;
4573-
} else { // clamped_begin_idx > clamped_end_idx, tf_stride < 0
4574-
4575-
// clamped_begin_idx is [0, d] && clamped_begin_idx >
4576-
// clamped_end_idx,
4577-
// which implies clamped_end_idx is [-1,d-1]
4578-
// Type safety: With clamped_end_idx in [-1,d-1],
4579-
// dim - 1 - clamped_end_idx is in [0, dim]. Hence type safe
4580-
ng_end_idx = dim - 1 - clamped_end_idx;
4581-
4582-
if (clamped_begin_idx == dim) {
4583-
clamped_begin_idx = dim - 1;
4584-
}
4585-
// Note clamped_begin_idx != dim here.
4586-
// If clamped_begin_idx==dim && clamped_end_idx==dim, then "Empty
4587-
// due to matching indexes" handles it
4588-
// If clamped_begin_idx==dim && clamped_end_idx<dim, then 2 cases:
4589-
// tf_stride > 0: then "Empty due to mismatching directions"
4590-
// handles it
4591-
// tf_stride < 0: Then we set it to dim-1 above
4592-
// Consider the case of dim=3, where in tf notation we have:
4593-
// [4:1:-1], in clampe notation, we get [3:1:-1], which really
4594-
// means
4595-
// [2:1:-1]
4596-
4597-
// Type safety: Since clamped_begin_idx is [0, d-1] here, it is
4598-
// type
4599-
// safe
4600-
ng_begin_idx = dim - 1 - clamped_begin_idx;
4601-
needs_reverse = 1;
4602-
}
4603-
}
4445+
// Desired implementation ==>
4446+
// SaveNgOp(ng_op_map, op->name(),
4447+
// ConstructNgNode<ng::op::StridedSlice>(op->name(), begin_vec,
4448+
// end_vec, stride_vec,
4449+
// tf_begin_mask, tf_end_mask,
4450+
// tf_new_axis_mask, tf_shrink_axis_mask,
4451+
// tf_ellipsis_mask));
4452+
4453+
// Temporarily we are borrowing this implementation from nGraph-core until
4454+
// ng::op::StridedSlice is released for use in ngraph-bridge
4455+
4456+
auto convert_mask_to_axes = [](const int mask) {
4457+
ng::AxisSet axes{};
4458+
for (auto i = 0; i < sizeof(int) * 8; ++i) {
4459+
if ((unsigned char)(mask >> i & 0x01) == 1) {
4460+
axes.emplace(i);
46044461
}
4605-
} else {
4606-
// cases when clamped indexes are in [0,d] and hence can be directly
4607-
// copied
4608-
// TODO: what about tf_begin=d, shrink=T, then clamped_end_idx = d, so a
4609-
// 0-d axis.
4610-
// But since shrink is on, that is reshaped and the 0-d axis is removed?
4611-
// Is that a valid config, as shrink_axis must get an axis with dim = 1,
4612-
// right?
4613-
4614-
ng_begin_idx = clamped_begin_idx;
4615-
ng_end_idx = clamped_end_idx;
46164462
}
4617-
return std::make_tuple(ng_begin_idx, ng_end_idx, std::abs(tf_stride),
4618-
needs_reverse);
4463+
return axes;
46194464
};
46204465

4621-
auto extract_bit = [](int bit_mask, int bit_location) {
4622-
return (bit_mask & (1 << bit_location)) != 0;
4623-
};
4466+
ng::Shape input_shape = ng_input->get_shape();
46244467

4625-
auto dim_vec = ng_input->get_shape();
4626-
auto in_rank = dim_vec.size();
4627-
4628-
if (begin_vec.size() > in_rank) {
4629-
return errors::InvalidArgument("Index out of range using input dim ",
4630-
begin_vec.size(), "; input has only ",
4631-
in_rank, " dims");
4632-
}
4633-
4634-
// TODO/Note/Question: Are begin, end and stride vectors are of equal length
4635-
4636-
// begin, end and stride vectors may not have same size as input rank, hence
4637-
// initialize them with 0, dim and 1 respectively
4638-
vector<size_t> ng_begin_vec(in_rank, 0), ng_stride_vec(in_rank, 1);
4639-
vector<size_t> ng_end_vec(dim_vec);
4640-
vector<size_t> ng_needs_reversal(in_rank, 0); // should have been a
4641-
// vector<bool>, but it is
4642-
// optimized, so tie won't
4643-
// work. Hence using size_t
4644-
for (size_t dim_idx = 0; dim_idx < begin_vec.size(); dim_idx++) {
4645-
std::tie(ng_begin_vec[dim_idx], ng_end_vec[dim_idx], ng_stride_vec[dim_idx],
4646-
ng_needs_reversal[dim_idx]) =
4647-
tf_to_ng(begin_vec[dim_idx], end_vec[dim_idx], stride_vec[dim_idx],
4648-
dim_vec[dim_idx], extract_bit(tf_begin_mask, dim_idx),
4649-
extract_bit(tf_end_mask, dim_idx),
4650-
extract_bit(tf_shrink_axis_mask, dim_idx));
4651-
}
4652-
4653-
// filter out negative stride dimensions
4654-
vector<size_t> neg_strides;
4655-
for (size_t dim_idx = 0; dim_idx < in_rank; dim_idx++) {
4656-
if (ng_needs_reversal[dim_idx]) {
4657-
neg_strides.push_back(dim_idx);
4468+
std::vector<int64_t> begin_vec_longint(begin_vec.begin(), begin_vec.end());
4469+
std::vector<int64_t> end_vec_longint(end_vec.begin(), end_vec.end());
4470+
std::vector<int64_t> stride_vec_longint(stride_vec.begin(), stride_vec.end());
4471+
4472+
NGRAPH_VLOG(4) << "Arguments to make_slice_plan: Input shape: " << input_shape
4473+
<< ", begin vector: " << ng::join(begin_vec_longint)
4474+
<< ", end vector: " << ng::join(end_vec_longint)
4475+
<< ", stride vector: " << ng::join(stride_vec_longint)
4476+
<< ", begin mask: " << tf_begin_mask
4477+
<< ", end mask: " << tf_end_mask
4478+
<< ", new axis mask: " << tf_new_axis_mask
4479+
<< ", shrink axis mask: " << tf_shrink_axis_mask
4480+
<< ", ellipsis mask: " << tf_ellipsis_mask;
4481+
4482+
auto in_rank = ng_input->get_shape().size();
4483+
if (tf_new_axis_mask == 0) {
4484+
if (begin_vec_longint.size() > in_rank) {
4485+
return errors::InvalidArgument("Index out of range using input dim ",
4486+
begin_vec_longint.size(),
4487+
"; input has only ", in_rank, " dims");
46584488
}
46594489
}
46604490

4661-
// atleast one stride was negative, in which case reverse the input
4662-
if (neg_strides.size() > 0)
4663-
ng_input =
4664-
ConstructNgNode<ng::op::Reverse>(op->name(), ng_input, neg_strides);
4665-
NGRAPH_VLOG(3) << "NG Lower Vector " << ng::join(ng_begin_vec);
4666-
NGRAPH_VLOG(3) << "NG End Vector " << ng::join(ng_end_vec);
4667-
NGRAPH_VLOG(3) << "NG Stride Vector " << ng::join(ng_stride_vec);
4668-
NGRAPH_VLOG(3) << "NG Needs Reversal: " << ng::join(ng_needs_reversal);
4669-
4670-
std::shared_ptr<ng::Node> ng_strided_slice = ConstructNgNode<ng::op::Slice>(
4671-
op->name(), ng_input, ng_begin_vec, ng_end_vec, ng_stride_vec);
4672-
4673-
if (tf_shrink_axis_mask) {
4674-
int64 shrink_axis_mask = tf_shrink_axis_mask;
4675-
vector<size_t> output_shape;
4676-
4677-
// Note: do not use rank instead of ng_begin_vec.size()
4678-
// since ng_begin_vec.size() can be less than rank, and
4679-
// shrink_mask will have atmost ng_begin_vec.size() elements
4680-
for (size_t i = 0; i < ng_begin_vec.size(); i++) {
4681-
if ((shrink_axis_mask & 1) != 1) {
4682-
output_shape.push_back(ng_end_vec[i] - ng_begin_vec[i]);
4683-
} else {
4684-
// TODO: must it equal 1 or can it be 0 too?
4685-
if (ng_end_vec[i] - ng_begin_vec[i] > 1)
4686-
return errors::InvalidArgument(
4687-
"Trying to shrink specification ", i,
4688-
"where tf begin, end, strides are: ", begin_vec[i], ":",
4689-
end_vec[i], ":", stride_vec[i],
4690-
". nGraph begin, end, stride are: ", ng_begin_vec[i], ":",
4691-
ng_end_vec[i], ":", ng_stride_vec[i],
4692-
". nGraph's begin and end have difference greater than 1");
4693-
}
4694-
shrink_axis_mask >>= 1;
4695-
}
4491+
auto sp = ng::make_slice_plan(
4492+
input_shape, begin_vec_longint, end_vec_longint, stride_vec_longint,
4493+
convert_mask_to_axes(tf_begin_mask), convert_mask_to_axes(tf_end_mask),
4494+
convert_mask_to_axes(tf_new_axis_mask),
4495+
convert_mask_to_axes(tf_shrink_axis_mask),
4496+
convert_mask_to_axes(tf_ellipsis_mask));
4497+
4498+
NGRAPH_VLOG(4) << "Return values of make_slice_plan: begin: "
4499+
<< ng::join(sp.begins) << ", end: " << ng::join(sp.ends)
4500+
<< ", stride: " << ng::join(sp.strides)
4501+
<< ", reshape input shape: " << sp.reshape_in_shape
4502+
<< ", reshape output shape: " << sp.reshape_out_shape
4503+
<< ", reverse axis: " << sp.reverse_axes;
4504+
4505+
// To handle cases like x[2:2], where shape(x) = [1],
4506+
// TF returns shape = [0], empty vector
4507+
// make_slice_plan returns begin=2, end=2, but that is > 1
4508+
// So must clamp them
4509+
// Another example:
4510+
// for dimension 3, Also 2:3:-1 gives 4:4, which will also fail if we try to
4511+
// construct slice. So must clamp to 2:2 etc
4512+
4513+
auto clamp = [](int64_t x, int64_t min, int64_t max) {
4514+
return x > max ? max : (x < min ? min : x);
4515+
};
4516+
for (int i = 0; i < sp.begins.size(); i++) {
4517+
sp.begins[i] = clamp(sp.begins[i], 0, input_shape[i]);
4518+
sp.ends[i] = clamp(sp.ends[i], 0, input_shape[i]);
4519+
}
46964520

4697-
NGRAPH_VLOG(3) << "Shrink axis mask " << tf_shrink_axis_mask;
4698-
ng::Shape ng_final_shape(output_shape);
4699-
ng::AxisVector ng_axis_order(input_shape.size());
4521+
// Need to convert int64_t to size_t
4522+
std::vector<size_t> sp_begins(sp.begins.begin(), sp.begins.end());
4523+
std::vector<size_t> sp_ends(sp.ends.begin(), sp.ends.end());
4524+
std::vector<size_t> sp_strides(sp.strides.begin(), sp.strides.end());
4525+
4526+
shared_ptr<ng::Node> ng_result = ConstructNgNode<ng::op::Slice>(
4527+
op->name(), ng_input, sp_begins, sp_ends, sp_strides);
4528+
4529+
if (sp.reshape_in_shape != sp.reshape_out_shape) {
4530+
ng::Shape ng_out_shape(sp.reshape_out_shape);
4531+
ng::AxisVector ng_axis_order(sp.reshape_in_shape.size());
4532+
// std::iota Fills the range [first, last) with sequentially increasing
4533+
// values,
4534+
// starting with value and repetitively evaluating ++value
47004535
std::iota(ng_axis_order.begin(), ng_axis_order.end(), 0);
47014536

4702-
NGRAPH_VLOG(3) << " Output shape " << ng::join(output_shape);
4537+
NGRAPH_VLOG(3) << " Output shape " << ng::join(ng_out_shape);
47034538
NGRAPH_VLOG(3) << " NG axis order " << ng::join(ng_axis_order);
47044539

4705-
ng_strided_slice = ConstructNgNode<ng::op::Reshape>(
4706-
op->name(), ng_strided_slice, ng_axis_order, ng_final_shape);
4540+
ng_result = ConstructNgNode<ng::op::Reshape>(op->name(), ng_result,
4541+
ng_axis_order, ng_out_shape);
47074542
}
47084543

4709-
// TODO: assert size in this dim was 1
4710-
// TODO: assert new_axis_mask and tf_shrink_axis_mask are not set at the
4711-
// same
4712-
// time?
4713-
// TODO: tf_new_axis_mask can exceed rank
4544+
if (!sp.reverse_axes.empty()) {
4545+
ng_result = ConstructNgNode<ng::op::Reverse>(op->name(), ng_result,
4546+
sp.reverse_axes);
4547+
}
47144548

4715-
SaveNgOp(ng_op_map, op->name(), ng_strided_slice);
4549+
SaveNgOp(ng_op_map, op->name(), ng_result);
47164550
return Status::OK();
47174551
}
47184552

0 commit comments

Comments
 (0)