File tree 1 file changed +12
-5
lines changed 1 file changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -65,13 +65,20 @@ using ::xla::StackFrameIndexProto;
65
65
66
66
Shape ResolveShapeIndex (const xla::ShapeProto& shape_proto,
67
67
absl::Span<const int64_t > shape_index) {
68
- if (shape_index.empty ()) return Shape (shape_proto);
69
68
// Choosing the last subshape to maintain historical behavior.
70
- int64_t i = shape_index.back ();
71
- if (i >= shape_proto.tuple_shapes_size ()) {
72
- return Shape (shape_proto);
69
+ const xla::ShapeProto* proto = &shape_proto;
70
+ if (!shape_index.empty ()) {
71
+ int64_t i = shape_index.back ();
72
+ if (i < shape_proto.tuple_shapes_size ()) {
73
+ proto = &shape_proto.tuple_shapes (i);
74
+ }
75
+ }
76
+ absl::StatusOr<Shape> shape = Shape::FromProto (*proto);
77
+ if (!shape.ok ()) {
78
+ LOG (DFATAL) << " Failed to resolve shape index: " << shape.status ();
79
+ return Shape ();
73
80
}
74
- return Shape (shape_proto. tuple_shapes (i)) ;
81
+ return *shape ;
75
82
}
76
83
77
84
std::string ShapeDescription (const Shape& shape) {
You can’t perform that action at this time.
0 commit comments