|
28 | 28 | #include "ngraph/op/argmin.hpp" |
29 | 29 | #include "ngraph/op/experimental/layers/interpolate.hpp" |
30 | 30 | #include "ngraph/op/util/logical_reduction.hpp" |
| 31 | +#include "ngraph/slice_plan.hpp" |
31 | 32 |
|
32 | 33 | #include "logging/ngraph_log.h" |
33 | 34 | #include "ngraph_bridge/ngraph_api.h" |
@@ -4410,7 +4411,6 @@ static Status TranslateSqueezeOp(const Node* op, |
4410 | 4411 | static Status TranslateStridedSliceOp( |
4411 | 4412 | const Node* op, const std::vector<const Tensor*>& static_input_map, |
4412 | 4413 | Builder::OpMap& ng_op_map) { |
4413 | | - // TODO: implement new_axis_mask, ellipsis_mask |
4414 | 4414 | shared_ptr<ng::Node> ng_input; |
4415 | 4415 | TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, &ng_input)); |
4416 | 4416 |
|
@@ -4442,277 +4442,111 @@ static Status TranslateStridedSliceOp( |
4442 | 4442 | TF_RETURN_IF_ERROR( |
4443 | 4443 | GetStaticInputVector(op, 3, static_input_map, &stride_vec)); |
4444 | 4444 |
|
4445 | | - auto& input_shape = ng_input->get_shape(); |
4446 | | - |
4447 | | - // Summary: Convert tf indexes (-inf, inf) to clamped_begin_idx [0, d] and |
4448 | | - // clamped_end_idx [-1, d], which are then converted to ngraph indexes [0, |
4449 | | - // d] |
4450 | | - // tf->ng is done through tf_to_ng, which calls clamper, which converts |
4451 | | - // tf->clamped |
4452 | | - |
4453 | | - // Graph/function for tf->cmapled |
4454 | | - // | ....... <-- y = max_val (max_val = 5) |
4455 | | - // .| . |
4456 | | - // . | . |
4457 | | - // . | . <-- y = x>=0 ? x : x+max_val |
4458 | | - // . |. |
4459 | | - // -.-.-.----.------------ <-- y = 0 (for inclusive) |
4460 | | - // * * | <-- y = -1 (for exclusive) |
4461 | | - // | |
4462 | | - // X axis: TF indexes. Y axis: Clamped indexes |
4463 | | - |
4464 | | - // clamper is a function that implements the graph above. |
4465 | | - // For inclusive, the graph is clamped at 0 and dim-1 |
4466 | | - // Given dimension d, [0, d-1] are valid locations. |
4467 | | - // -1 represents std::rend(). d represents std::end(). |
4468 | | - // These two are useful for representing exclusive boundaries for end-ranges |
4469 | | - // Example for dim = 3: |
4470 | | - // ranges: (-inf,-d)| [-d,0) |[0,d-1]|(d-1,inf) |
4471 | | - // TF index: -5 -4 |-3 -2 -1 | 0 1 2 | 3 4 5 |
4472 | | - // clamped begin (inclusive): 0 0 | 0 1 2 | 0 1 2 | 3 3 3 |
4473 | | - // clamped end (exclusive): -1 -1 | 0 1 2 | 0 1 2 | 3 3 3 |
4474 | | - auto clamper = [](int idx, size_t dim, bool inclusive) { |
4475 | | - // if idx is in [-(d-1), d-1], then its same for both inclusive and |
4476 | | - // exclusive |
4477 | | - // The first 2 cases breaks down this range |
4478 | | - if (idx >= 0 && idx <= (static_cast<int>(dim) - 1)) { |
4479 | | - return idx; |
4480 | | - } else if (idx < 0 && |
4481 | | - idx + static_cast<int>(dim) >= |
4482 | | - 0) { // careful not to do idx >= -dim |
4483 | | - // (since dim is unsigned) |
4484 | | - return idx + static_cast<int>( |
4485 | | - dim); // Type casting to int to enable unambiguous auto |
4486 | | - // type inference of return type |
4487 | | - } else if (idx > static_cast<int>(dim) - 1) { |
4488 | | - return static_cast<int>(dim); |
4489 | | - } else if (idx + static_cast<int>(dim) < 0) { |
4490 | | - // The next case handles the clamping (differently for inclusive and |
4491 | | - // exclusive cases) |
4492 | | - |
4493 | | - // careful not to do idx < -dim (since dim is unsigned) |
4494 | | - return 0 - (inclusive ? 0 : 1); |
4495 | | - } |
4496 | | - // Default case |
4497 | | - return 0; |
4498 | | - }; |
4499 | | - |
4500 | | - auto tf_to_ng = [clamper](int tf_begin_idx, int tf_end_idx, int tf_stride, |
4501 | | - size_t dim, bool begin_mask, bool end_mask, |
4502 | | - bool shrink_mask) { |
4503 | | - // if begin mask is present, depending on stride sign use 0 (std::begin) |
4504 | | - // or |
4505 | | - // dim-1 (std::rbegin) |
4506 | | - // clamped_end_idx could line in [-1, d] |
4507 | | - int tf_ignore_begin_if_needed = |
4508 | | - begin_mask ? (tf_stride > 0 ? 0 : dim - 1) : tf_begin_idx; |
4509 | | - // if end mask is present, depending on stride sign use -1 (std::rend) or |
4510 | | - // dim (std::end). |
4511 | | - // However note, we cannot set to -1, since it has another meaning, hence |
4512 | | - // setting to -(dim+1), which would translate to -1 in clamped coordinates |
4513 | | - // take care to convert dim from sixze_t to int |
4514 | | - int tf_ignore_end_if_needed = |
4515 | | - end_mask ? (tf_stride > 0 ? dim : (-((int)dim + 1))) : tf_end_idx; |
4516 | | - // using size_t for clamped_begin_idx because: clamped_begin_idx is |
4517 | | - // inclusive, so it must lie in [0, dim-1] |
4518 | | - size_t clamped_begin_idx = clamper(tf_ignore_begin_if_needed, dim, true); |
4519 | | - int64 clamped_end_idx = |
4520 | | - clamper(shrink_mask ? clamped_begin_idx + 1 : tf_ignore_end_if_needed, |
4521 | | - dim, false); |
4522 | | - |
4523 | | - // Now we have converted semantically non-monotonic and unbounded TF |
4524 | | - // indexes |
4525 | | - // (-inf, inf) to bounded and monotonic clamped indexes [-1, d] |
4526 | | - // Now we need to convert clamped indexes [-1, d] to ngraph indexes [0, d] |
4527 | | - // (taking care of reversal in case of negative strides) |
4528 | | - |
4529 | | - size_t needs_reverse = 0; |
4530 | | - size_t ng_begin_idx, ng_end_idx; |
4531 | | - |
4532 | | - if (!shrink_mask) { |
4533 | | - if ((int)clamped_begin_idx == clamped_end_idx) { |
4534 | | - // Empty due to matching indexes |
4535 | | - ng_begin_idx = clamped_begin_idx; |
4536 | | - // Type safety: clamped_begin_idx == clamped_end_idx implies, |
4537 | | - // clamped_end_idx!=-1 (since clamped_begin_idx cannot be -1), hence |
4538 | | - // end |
4539 | | - // index assignment is type safe |
4540 | | - ng_end_idx = clamped_end_idx; |
4541 | | - } else { // In the whole of this else: clamped_begin_idx != |
4542 | | - // clamped_end_idx, so !(a < b) iff a > b and vice versa when |
4543 | | - // comparing the indexes |
4544 | | - // take care to use (int) typecase when comparing int and size_t |
4545 | | - if (((int)clamped_begin_idx < clamped_end_idx) != (tf_stride > 0)) { |
4546 | | - // Empty due to mismatching directions |
4547 | | - ng_begin_idx = clamped_begin_idx; |
4548 | | - // Type safe: since clamped_begin_idx is size_t (>0) |
4549 | | - // [0:-4:1] in TF would convert to [0:-1:1] in clamped domain. hence |
4550 | | - // we do not assign ng_end_idx = clamped_end_idx (which would not be |
4551 | | - // type safe due to the -1) |
4552 | | - ng_end_idx = clamped_begin_idx; |
4553 | | - // Any assignment where ng_begin_idx = ng_end_idx = x (where 0 <= x |
4554 | | - // <= |
4555 | | - // d-1) would have worked for the 2 empty cases above |
4556 | | - } |
4557 | | - // Anything after this is non-empty. Anything before this has dealt |
4558 | | - // with |
4559 | | - // empty cases |
4560 | | - else { |
4561 | | - // in this case either (clamped_begin_idx < clamped_end_idx && |
4562 | | - // tf_stride > 0) or (clamped_begin_idx > clamped_end_idx && |
4563 | | - // tf_stride |
4564 | | - // < 0) |
4565 | | - // that is clamped_begin_idx < clamped_end_idx <==> tf_stride > 0. |
4566 | | - // hence using only 1 of the clauses is enough |
4567 | | - if (tf_stride > 0) { |
4568 | | - ng_begin_idx = clamped_begin_idx; |
4569 | | - // Type safety: tf_stride > 0 ==> clamped_begin_idx < |
4570 | | - // clamped_end_idx. clamped_begin_idx could be 0, |
4571 | | - // which means clamped_end_idx > 0. Hence type-safe |
4572 | | - ng_end_idx = clamped_end_idx; |
4573 | | - } else { // clamped_begin_idx > clamped_end_idx, tf_stride < 0 |
4574 | | - |
4575 | | - // clamped_begin_idx is [0, d] && clamped_begin_idx > |
4576 | | - // clamped_end_idx, |
4577 | | - // which implies clamped_end_idx is [-1,d-1] |
4578 | | - // Type safety: With clamped_end_idx in [-1,d-1], |
4579 | | - // dim - 1 - clamped_end_idx is in [0, dim]. Hence type safe |
4580 | | - ng_end_idx = dim - 1 - clamped_end_idx; |
4581 | | - |
4582 | | - if (clamped_begin_idx == dim) { |
4583 | | - clamped_begin_idx = dim - 1; |
4584 | | - } |
4585 | | - // Note clamped_begin_idx != dim here. |
4586 | | - // If clamped_begin_idx==dim && clamped_end_idx==dim, then "Empty |
4587 | | - // due to matching indexes" handles it |
4588 | | - // If clamped_begin_idx==dim && clamped_end_idx<dim, then 2 cases: |
4589 | | - // tf_stride > 0: then "Empty due to mismatching directions" |
4590 | | - // handles it |
4591 | | - // tf_stride < 0: Then we set it to dim-1 above |
4592 | | - // Consider the case of dim=3, where in tf notation we have: |
4593 | | - // [4:1:-1], in clampe notation, we get [3:1:-1], which really |
4594 | | - // means |
4595 | | - // [2:1:-1] |
4596 | | - |
4597 | | - // Type safety: Since clamped_begin_idx is [0, d-1] here, it is |
4598 | | - // type |
4599 | | - // safe |
4600 | | - ng_begin_idx = dim - 1 - clamped_begin_idx; |
4601 | | - needs_reverse = 1; |
4602 | | - } |
4603 | | - } |
| 4445 | + // Desired implementation ==> |
| 4446 | + // SaveNgOp(ng_op_map, op->name(), |
| 4447 | + // ConstructNgNode<ng::op::StridedSlice>(op->name(), begin_vec, |
| 4448 | + // end_vec, stride_vec, |
| 4449 | + // tf_begin_mask, tf_end_mask, |
| 4450 | + // tf_new_axis_mask, tf_shrink_axis_mask, |
| 4451 | + // tf_ellipsis_mask)); |
| 4452 | + |
| 4453 | + // Temporarily we are borrowing this implementation from nGraph-core until |
| 4454 | + // ng::op::StridedSlice is released for use in ngraph-bridge |
| 4455 | + |
| 4456 | + auto convert_mask_to_axes = [](const int mask) { |
| 4457 | + ng::AxisSet axes{}; |
| 4458 | + for (auto i = 0; i < sizeof(int) * 8; ++i) { |
| 4459 | + if ((unsigned char)(mask >> i & 0x01) == 1) { |
| 4460 | + axes.emplace(i); |
4604 | 4461 | } |
4605 | | - } else { |
4606 | | - // cases when clamped indexes are in [0,d] and hence can be directly |
4607 | | - // copied |
4608 | | - // TODO: what about tf_begin=d, shrink=T, then clamped_end_idx = d, so a |
4609 | | - // 0-d axis. |
4610 | | - // But since shrink is on, that is reshaped and the 0-d axis is removed? |
4611 | | - // Is that a valid config, as shrink_axis must get an axis with dim = 1, |
4612 | | - // right? |
4613 | | - |
4614 | | - ng_begin_idx = clamped_begin_idx; |
4615 | | - ng_end_idx = clamped_end_idx; |
4616 | 4462 | } |
4617 | | - return std::make_tuple(ng_begin_idx, ng_end_idx, std::abs(tf_stride), |
4618 | | - needs_reverse); |
| 4463 | + return axes; |
4619 | 4464 | }; |
4620 | 4465 |
|
4621 | | - auto extract_bit = [](int bit_mask, int bit_location) { |
4622 | | - return (bit_mask & (1 << bit_location)) != 0; |
4623 | | - }; |
| 4466 | + ng::Shape input_shape = ng_input->get_shape(); |
4624 | 4467 |
|
4625 | | - auto dim_vec = ng_input->get_shape(); |
4626 | | - auto in_rank = dim_vec.size(); |
4627 | | - |
4628 | | - if (begin_vec.size() > in_rank) { |
4629 | | - return errors::InvalidArgument("Index out of range using input dim ", |
4630 | | - begin_vec.size(), "; input has only ", |
4631 | | - in_rank, " dims"); |
4632 | | - } |
4633 | | - |
4634 | | - // TODO/Note/Question: Are begin, end and stride vectors are of equal length |
4635 | | - |
4636 | | - // begin, end and stride vectors may not have same size as input rank, hence |
4637 | | - // initialize them with 0, dim and 1 respectively |
4638 | | - vector<size_t> ng_begin_vec(in_rank, 0), ng_stride_vec(in_rank, 1); |
4639 | | - vector<size_t> ng_end_vec(dim_vec); |
4640 | | - vector<size_t> ng_needs_reversal(in_rank, 0); // should have been a |
4641 | | - // vector<bool>, but it is |
4642 | | - // optimized, so tie won't |
4643 | | - // work. Hence using size_t |
4644 | | - for (size_t dim_idx = 0; dim_idx < begin_vec.size(); dim_idx++) { |
4645 | | - std::tie(ng_begin_vec[dim_idx], ng_end_vec[dim_idx], ng_stride_vec[dim_idx], |
4646 | | - ng_needs_reversal[dim_idx]) = |
4647 | | - tf_to_ng(begin_vec[dim_idx], end_vec[dim_idx], stride_vec[dim_idx], |
4648 | | - dim_vec[dim_idx], extract_bit(tf_begin_mask, dim_idx), |
4649 | | - extract_bit(tf_end_mask, dim_idx), |
4650 | | - extract_bit(tf_shrink_axis_mask, dim_idx)); |
4651 | | - } |
4652 | | - |
4653 | | - // filter out negative stride dimensions |
4654 | | - vector<size_t> neg_strides; |
4655 | | - for (size_t dim_idx = 0; dim_idx < in_rank; dim_idx++) { |
4656 | | - if (ng_needs_reversal[dim_idx]) { |
4657 | | - neg_strides.push_back(dim_idx); |
| 4468 | + std::vector<int64_t> begin_vec_longint(begin_vec.begin(), begin_vec.end()); |
| 4469 | + std::vector<int64_t> end_vec_longint(end_vec.begin(), end_vec.end()); |
| 4470 | + std::vector<int64_t> stride_vec_longint(stride_vec.begin(), stride_vec.end()); |
| 4471 | + |
| 4472 | + NGRAPH_VLOG(4) << "Arguments to make_slice_plan: Input shape: " << input_shape |
| 4473 | + << ", begin vector: " << ng::join(begin_vec_longint) |
| 4474 | + << ", end vector: " << ng::join(end_vec_longint) |
| 4475 | + << ", stride vector: " << ng::join(stride_vec_longint) |
| 4476 | + << ", begin mask: " << tf_begin_mask |
| 4477 | + << ", end mask: " << tf_end_mask |
| 4478 | + << ", new axis mask: " << tf_new_axis_mask |
| 4479 | + << ", shrink axis mask: " << tf_shrink_axis_mask |
| 4480 | + << ", ellipsis mask: " << tf_ellipsis_mask; |
| 4481 | + |
| 4482 | + auto in_rank = ng_input->get_shape().size(); |
| 4483 | + if (tf_new_axis_mask == 0) { |
| 4484 | + if (begin_vec_longint.size() > in_rank) { |
| 4485 | + return errors::InvalidArgument("Index out of range using input dim ", |
| 4486 | + begin_vec_longint.size(), |
| 4487 | + "; input has only ", in_rank, " dims"); |
4658 | 4488 | } |
4659 | 4489 | } |
4660 | 4490 |
|
4661 | | - // atleast one stride was negative, in which case reverse the input |
4662 | | - if (neg_strides.size() > 0) |
4663 | | - ng_input = |
4664 | | - ConstructNgNode<ng::op::Reverse>(op->name(), ng_input, neg_strides); |
4665 | | - NGRAPH_VLOG(3) << "NG Lower Vector " << ng::join(ng_begin_vec); |
4666 | | - NGRAPH_VLOG(3) << "NG End Vector " << ng::join(ng_end_vec); |
4667 | | - NGRAPH_VLOG(3) << "NG Stride Vector " << ng::join(ng_stride_vec); |
4668 | | - NGRAPH_VLOG(3) << "NG Needs Reversal: " << ng::join(ng_needs_reversal); |
4669 | | - |
4670 | | - std::shared_ptr<ng::Node> ng_strided_slice = ConstructNgNode<ng::op::Slice>( |
4671 | | - op->name(), ng_input, ng_begin_vec, ng_end_vec, ng_stride_vec); |
4672 | | - |
4673 | | - if (tf_shrink_axis_mask) { |
4674 | | - int64 shrink_axis_mask = tf_shrink_axis_mask; |
4675 | | - vector<size_t> output_shape; |
4676 | | - |
4677 | | - // Note: do not use rank instead of ng_begin_vec.size() |
4678 | | - // since ng_begin_vec.size() can be less than rank, and |
4679 | | - // shrink_mask will have atmost ng_begin_vec.size() elements |
4680 | | - for (size_t i = 0; i < ng_begin_vec.size(); i++) { |
4681 | | - if ((shrink_axis_mask & 1) != 1) { |
4682 | | - output_shape.push_back(ng_end_vec[i] - ng_begin_vec[i]); |
4683 | | - } else { |
4684 | | - // TODO: must it equal 1 or can it be 0 too? |
4685 | | - if (ng_end_vec[i] - ng_begin_vec[i] > 1) |
4686 | | - return errors::InvalidArgument( |
4687 | | - "Trying to shrink specification ", i, |
4688 | | - "where tf begin, end, strides are: ", begin_vec[i], ":", |
4689 | | - end_vec[i], ":", stride_vec[i], |
4690 | | - ". nGraph begin, end, stride are: ", ng_begin_vec[i], ":", |
4691 | | - ng_end_vec[i], ":", ng_stride_vec[i], |
4692 | | - ". nGraph's begin and end have difference greater than 1"); |
4693 | | - } |
4694 | | - shrink_axis_mask >>= 1; |
4695 | | - } |
| 4491 | + auto sp = ng::make_slice_plan( |
| 4492 | + input_shape, begin_vec_longint, end_vec_longint, stride_vec_longint, |
| 4493 | + convert_mask_to_axes(tf_begin_mask), convert_mask_to_axes(tf_end_mask), |
| 4494 | + convert_mask_to_axes(tf_new_axis_mask), |
| 4495 | + convert_mask_to_axes(tf_shrink_axis_mask), |
| 4496 | + convert_mask_to_axes(tf_ellipsis_mask)); |
| 4497 | + |
| 4498 | + NGRAPH_VLOG(4) << "Return values of make_slice_plan: begin: " |
| 4499 | + << ng::join(sp.begins) << ", end: " << ng::join(sp.ends) |
| 4500 | + << ", stride: " << ng::join(sp.strides) |
| 4501 | + << ", reshape input shape: " << sp.reshape_in_shape |
| 4502 | + << ", reshape output shape: " << sp.reshape_out_shape |
| 4503 | + << ", reverse axis: " << sp.reverse_axes; |
| 4504 | + |
| 4505 | + // To handle cases like x[2:2], where shape(x) = [1], |
| 4506 | + // TF returns shape = [0], empty vector |
| 4507 | + // make_slice_plan returns begin=2, end=2, but that is > 1 |
| 4508 | + // So must clamp them |
| 4509 | + // Another example: |
| 4510 | + // for dimension 3, Also 2:3:-1 gives 4:4, which will also fail if we try to |
| 4511 | + // construct slice. So must clamp to 2:2 etc |
| 4512 | + |
| 4513 | + auto clamp = [](int64_t x, int64_t min, int64_t max) { |
| 4514 | + return x > max ? max : (x < min ? min : x); |
| 4515 | + }; |
| 4516 | + for (int i = 0; i < sp.begins.size(); i++) { |
| 4517 | + sp.begins[i] = clamp(sp.begins[i], 0, input_shape[i]); |
| 4518 | + sp.ends[i] = clamp(sp.ends[i], 0, input_shape[i]); |
| 4519 | + } |
4696 | 4520 |
|
4697 | | - NGRAPH_VLOG(3) << "Shrink axis mask " << tf_shrink_axis_mask; |
4698 | | - ng::Shape ng_final_shape(output_shape); |
4699 | | - ng::AxisVector ng_axis_order(input_shape.size()); |
| 4521 | + // Need to convert int64_t to size_t |
| 4522 | + std::vector<size_t> sp_begins(sp.begins.begin(), sp.begins.end()); |
| 4523 | + std::vector<size_t> sp_ends(sp.ends.begin(), sp.ends.end()); |
| 4524 | + std::vector<size_t> sp_strides(sp.strides.begin(), sp.strides.end()); |
| 4525 | + |
| 4526 | + shared_ptr<ng::Node> ng_result = ConstructNgNode<ng::op::Slice>( |
| 4527 | + op->name(), ng_input, sp_begins, sp_ends, sp_strides); |
| 4528 | + |
| 4529 | + if (sp.reshape_in_shape != sp.reshape_out_shape) { |
| 4530 | + ng::Shape ng_out_shape(sp.reshape_out_shape); |
| 4531 | + ng::AxisVector ng_axis_order(sp.reshape_in_shape.size()); |
| 4532 | + // std::iota Fills the range [first, last) with sequentially increasing |
| 4533 | + // values, |
| 4534 | + // starting with value and repetitively evaluating ++value |
4700 | 4535 | std::iota(ng_axis_order.begin(), ng_axis_order.end(), 0); |
4701 | 4536 |
|
4702 | | - NGRAPH_VLOG(3) << " Output shape " << ng::join(output_shape); |
| 4537 | + NGRAPH_VLOG(3) << " Output shape " << ng::join(ng_out_shape); |
4703 | 4538 | NGRAPH_VLOG(3) << " NG axis order " << ng::join(ng_axis_order); |
4704 | 4539 |
|
4705 | | - ng_strided_slice = ConstructNgNode<ng::op::Reshape>( |
4706 | | - op->name(), ng_strided_slice, ng_axis_order, ng_final_shape); |
| 4540 | + ng_result = ConstructNgNode<ng::op::Reshape>(op->name(), ng_result, |
| 4541 | + ng_axis_order, ng_out_shape); |
4707 | 4542 | } |
4708 | 4543 |
|
4709 | | - // TODO: assert size in this dim was 1 |
4710 | | - // TODO: assert new_axis_mask and tf_shrink_axis_mask are not set at the |
4711 | | - // same |
4712 | | - // time? |
4713 | | - // TODO: tf_new_axis_mask can exceed rank |
| 4544 | + if (!sp.reverse_axes.empty()) { |
| 4545 | + ng_result = ConstructNgNode<ng::op::Reverse>(op->name(), ng_result, |
| 4546 | + sp.reverse_axes); |
| 4547 | + } |
4714 | 4548 |
|
4715 | | - SaveNgOp(ng_op_map, op->name(), ng_strided_slice); |
| 4549 | + SaveNgOp(ng_op_map, op->name(), ng_result); |
4716 | 4550 | return Status::OK(); |
4717 | 4551 | } |
4718 | 4552 |
|
|
0 commit comments