Skip to content

Commit 31f4592

Browse files
eddybLegNeato
authored andcommitted
builder: handle ptr_add-like GEPs (introduced by rust-lang/rust#118991).
1 parent e422b7a commit 31f4592

File tree

2 files changed

+244
-131
lines changed

2 files changed

+244
-131
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

+243-21
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use std::borrow::Cow;
3030
use std::cell::Cell;
3131
use std::convert::TryInto;
3232
use std::iter::{self, empty};
33+
use std::ops::RangeInclusive;
3334

3435
macro_rules! simple_op {
3536
(
@@ -412,9 +413,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
412413
// FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset`
413414
// could instead be doing all the extra digging itself.
414415
let mut indices = SmallVec::<[_; 8]>::new();
415-
while let Some((inner_indices, inner_ty)) =
416-
self.recover_access_chain_from_offset(leaf_ty, Size::ZERO, Some(size), None)
417-
{
416+
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
417+
leaf_ty,
418+
Size::ZERO,
419+
Some(size)..=Some(size),
420+
None,
421+
) {
418422
indices.extend(inner_indices);
419423
leaf_ty = inner_ty;
420424
}
@@ -439,8 +443,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
439443
}
440444

441445
/// If possible, return the appropriate `OpAccessChain` indices for going
442-
/// from a pointer to `ty`, to a pointer to some leaf field/element of size
443-
/// `leaf_size` (and optionally type `leaf_ty`), while adding `offset` bytes.
446+
/// from a pointer to `ty`, to a pointer to some leaf field/element having
447+
/// a size that fits `leaf_size_range` (and, optionally, the type `leaf_ty`),
448+
/// while adding `offset` bytes.
444449
///
445450
/// That is, try to turn `((_: *T) as *u8).add(offset) as *Leaf` into a series
446451
/// of struct field and array/vector element accesses.
@@ -449,7 +454,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
449454
mut ty: <Self as BackendTypes>::Type,
450455
mut offset: Size,
451456
// FIXME(eddyb) using `None` for "unsized" is a pretty bad design.
452-
leaf_size_or_unsized: Option<Size>,
457+
leaf_size_or_unsized_range: RangeInclusive<Option<Size>>,
453458
leaf_ty: Option<<Self as BackendTypes>::Type>,
454459
) -> Option<(SmallVec<[u32; 8]>, <Self as BackendTypes>::Type)> {
455460
assert_ne!(Some(ty), leaf_ty);
@@ -460,7 +465,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
460465
Sized(Size),
461466
Unsized,
462467
}
463-
let leaf_size = leaf_size_or_unsized.map_or(MaybeSized::Unsized, MaybeSized::Sized);
468+
let leaf_size_range = {
469+
let r = leaf_size_or_unsized_range;
470+
let [start, end] =
471+
[r.start(), r.end()].map(|x| x.map_or(MaybeSized::Unsized, MaybeSized::Sized));
472+
start..=end
473+
};
464474

465475
// NOTE(eddyb) `ty` and `ty_kind`/`ty_size` should be kept in sync.
466476
let mut ty_kind = self.lookup_type(ty);
@@ -493,7 +503,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
493503
if MaybeSized::Sized(offset_in_field) < field_ty_size
494504
// If the field is a zero sized type, check the
495505
// expected size and type to get the correct entry
496-
|| offset_in_field == Size::ZERO && leaf_size == MaybeSized::Sized(Size::ZERO) && leaf_ty == Some(field_ty)
506+
|| offset_in_field == Size::ZERO
507+
&& leaf_size_range.contains(&MaybeSized::Sized(Size::ZERO)) && leaf_ty == Some(field_ty)
497508
{
498509
Some((i, field_ty, field_ty_kind, field_ty_size, offset_in_field))
499510
} else {
@@ -525,19 +536,211 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
525536
}
526537

527538
// Avoid digging beyond the point the leaf could actually fit.
528-
if ty_size < leaf_size {
539+
if ty_size < *leaf_size_range.start() {
529540
return None;
530541
}
531542

532543
if offset == Size::ZERO
533-
&& ty_size == leaf_size
544+
&& leaf_size_range.contains(&ty_size)
534545
&& leaf_ty.map_or(true, |leaf_ty| leaf_ty == ty)
535546
{
536547
return Some((indices, ty));
537548
}
538549
}
539550
}
540551

552+
fn maybe_inbounds_gep(
553+
&mut self,
554+
ty: Word,
555+
ptr: SpirvValue,
556+
combined_indices: &[SpirvValue],
557+
is_inbounds: bool,
558+
) -> SpirvValue {
559+
let (&ptr_base_index, indices) = combined_indices.split_first().unwrap();
560+
561+
// The first index is an offset to the pointer, the rest are actual members.
562+
// https://llvm.org/docs/GetElementPtr.html
563+
// "An OpAccessChain instruction is the equivalent of an LLVM getelementptr instruction where the first index element is zero."
564+
// https://github.com/gpuweb/gpuweb/issues/33
565+
let mut result_pointee_type = ty;
566+
let indices: Vec<_> = indices
567+
.iter()
568+
.map(|index| {
569+
result_pointee_type = match self.lookup_type(result_pointee_type) {
570+
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => {
571+
element
572+
}
573+
_ => self.fatal(format!(
574+
"GEP not implemented for type {}",
575+
self.debug_type(result_pointee_type)
576+
)),
577+
};
578+
index.def(self)
579+
})
580+
.collect();
581+
582+
// Special-case field accesses through a `pointercast`, to accesss the
583+
// right field in the original type, for the `Logical` addressing model.
584+
let ptr = ptr.strip_ptrcasts();
585+
let ptr_id = ptr.def(self);
586+
let original_pointee_ty = match self.lookup_type(ptr.ty) {
587+
SpirvType::Pointer { pointee } => pointee,
588+
other => self.fatal(format!("gep called on non-pointer type: {other:?}")),
589+
};
590+
591+
// HACK(eddyb) `struct_gep` itself is falling out of use, as it's being
592+
// replaced upstream by `ptr_add` (aka `inbounds_gep` with byte offsets).
593+
//
594+
// FIXME(eddyb) get rid of everything other than:
595+
// - constant byte offset (`ptr_add`?)
596+
// - dynamic indexing of a single array
597+
let const_ptr_offset = self
598+
.builder
599+
.lookup_const_u64(ptr_base_index)
600+
.and_then(|idx| Some(idx * self.lookup_type(ty).sizeof(self)?));
601+
if let Some(const_ptr_offset) = const_ptr_offset {
602+
if let Some((base_indices, base_pointee_ty)) = self.recover_access_chain_from_offset(
603+
original_pointee_ty,
604+
const_ptr_offset,
605+
Some(Size::ZERO)..=None,
606+
None,
607+
) {
608+
// FIXME(eddyb) this condition is pretty limiting, but
609+
// eventually it shouldn't matter if GEPs are going away.
610+
if ty == base_pointee_ty || indices.is_empty() {
611+
let result_pointee_type = if indices.is_empty() {
612+
base_pointee_ty
613+
} else {
614+
result_pointee_type
615+
};
616+
let indices = base_indices
617+
.into_iter()
618+
.map(|idx| self.constant_u32(self.span(), idx).def(self))
619+
.chain(indices)
620+
.collect();
621+
return self.emit_access_chain(
622+
self.type_ptr_to(result_pointee_type),
623+
ptr_id,
624+
None,
625+
indices,
626+
is_inbounds,
627+
);
628+
}
629+
}
630+
}
631+
632+
let result_type = self.type_ptr_to(result_pointee_type);
633+
634+
// Check if `ptr_id` is defined by an `OpAccessChain`, and if it is,
635+
// grab its base pointer and indices.
636+
//
637+
// FIXME(eddyb) this could get ridiculously expensive, at the very least
638+
// it could use `.rev()`, hoping the base pointer was recently defined?
639+
let maybe_original_access_chain = if ty == original_pointee_ty {
640+
let emit = self.emit();
641+
let module = emit.module_ref();
642+
let func = &module.functions[emit.selected_function().unwrap()];
643+
let base_ptr_and_combined_indices = func
644+
.all_inst_iter()
645+
.find(|inst| inst.result_id == Some(ptr_id))
646+
.and_then(|ptr_def_inst| {
647+
if matches!(
648+
ptr_def_inst.class.opcode,
649+
Op::AccessChain | Op::InBoundsAccessChain
650+
) {
651+
let base_ptr = ptr_def_inst.operands[0].unwrap_id_ref();
652+
let indices = ptr_def_inst.operands[1..]
653+
.iter()
654+
.map(|op| op.unwrap_id_ref())
655+
.collect::<Vec<_>>();
656+
Some((base_ptr, indices))
657+
} else {
658+
None
659+
}
660+
});
661+
base_ptr_and_combined_indices
662+
} else {
663+
None
664+
};
665+
if let Some((original_ptr, mut original_indices)) = maybe_original_access_chain {
666+
// Transform the following:
667+
// OpAccessChain original_ptr [a, b, c]
668+
// OpPtrAccessChain ptr base [d, e, f]
669+
// into
670+
// OpAccessChain original_ptr [a, b, c + base, d, e, f]
671+
// to remove the need for OpPtrAccessChain
672+
let last = original_indices.last_mut().unwrap();
673+
*last = self
674+
.add(last.with_type(ptr_base_index.ty), ptr_base_index)
675+
.def(self);
676+
original_indices.extend(indices);
677+
return self.emit_access_chain(
678+
result_type,
679+
original_ptr,
680+
None,
681+
original_indices,
682+
is_inbounds,
683+
);
684+
}
685+
686+
// HACK(eddyb) temporary workaround for untyped pointers upstream.
687+
// FIXME(eddyb) replace with untyped memory SPIR-V + `qptr` or similar.
688+
let ptr = self.pointercast(ptr, self.type_ptr_to(ty));
689+
let ptr_id = ptr.def(self);
690+
691+
self.emit_access_chain(
692+
result_type,
693+
ptr_id,
694+
Some(ptr_base_index),
695+
indices,
696+
is_inbounds,
697+
)
698+
}
699+
700+
fn emit_access_chain(
701+
&self,
702+
result_type: <Self as BackendTypes>::Type,
703+
pointer: Word,
704+
ptr_base_index: Option<SpirvValue>,
705+
indices: Vec<Word>,
706+
is_inbounds: bool,
707+
) -> SpirvValue {
708+
let mut emit = self.emit();
709+
710+
let non_zero_ptr_base_index =
711+
ptr_base_index.filter(|&idx| self.builder.lookup_const_u64(idx) != Some(0));
712+
if let Some(ptr_base_index) = non_zero_ptr_base_index {
713+
let result = if is_inbounds {
714+
emit.in_bounds_ptr_access_chain(
715+
result_type,
716+
None,
717+
pointer,
718+
ptr_base_index.def(self),
719+
indices,
720+
)
721+
} else {
722+
emit.ptr_access_chain(
723+
result_type,
724+
None,
725+
pointer,
726+
ptr_base_index.def(self),
727+
indices,
728+
)
729+
}
730+
.unwrap();
731+
self.zombie(result, "cannot offset a pointer to an arbitrary element");
732+
result
733+
} else {
734+
if is_inbounds {
735+
emit.in_bounds_access_chain(result_type, None, pointer, indices)
736+
} else {
737+
emit.access_chain(result_type, None, pointer, indices)
738+
}
739+
.unwrap()
740+
}
741+
.with_type(result_type)
742+
}
743+
541744
fn fptoint_sat(
542745
&mut self,
543746
signed: bool,
@@ -1361,7 +1564,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
13611564
}
13621565

13631566
fn gep(&mut self, ty: Self::Type, ptr: Self::Value, indices: &[Self::Value]) -> Self::Value {
1364-
self.gep_help(ty, ptr, indices, false)
1567+
self.maybe_inbounds_gep(ty, ptr, indices, false)
13651568
}
13661569

13671570
fn inbounds_gep(
@@ -1370,7 +1573,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
13701573
ptr: Self::Value,
13711574
indices: &[Self::Value],
13721575
) -> Self::Value {
1373-
self.gep_help(ty, ptr, indices, true)
1576+
self.maybe_inbounds_gep(ty, ptr, indices, true)
13741577
}
13751578

13761579
fn struct_gep(&mut self, ty: Self::Type, ptr: Self::Value, idx: u64) -> Self::Value {
@@ -1395,6 +1598,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
13951598
"struct_gep not on struct, array, or vector type: {other:?}, index {idx}"
13961599
)),
13971600
};
1601+
let result_pointee_size = self.lookup_type(result_pointee_type).sizeof(self);
13981602
let result_type = self.type_ptr_to(result_pointee_type);
13991603

14001604
// Special-case field accesses through a `pointercast`, to accesss the
@@ -1407,7 +1611,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
14071611
if let Some((indices, _)) = self.recover_access_chain_from_offset(
14081612
original_pointee_ty,
14091613
offset,
1410-
self.lookup_type(result_pointee_type).sizeof(self),
1614+
result_pointee_size..=result_pointee_size,
14111615
Some(result_pointee_type),
14121616
) {
14131617
let original_ptr = ptr.def(self);
@@ -1586,9 +1790,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
15861790
// FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset`
15871791
// could instead be doing all the extra digging itself.
15881792
let mut indices = SmallVec::<[_; 8]>::new();
1589-
while let Some((inner_indices, inner_ty)) =
1590-
self.recover_access_chain_from_offset(leaf_ty, Size::ZERO, Some(size), None)
1591-
{
1793+
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
1794+
leaf_ty,
1795+
Size::ZERO,
1796+
Some(size)..=Some(size),
1797+
None,
1798+
) {
15921799
indices.extend(inner_indices);
15931800
leaf_ty = inner_ty;
15941801
}
@@ -1716,9 +1923,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
17161923
return self.const_bitcast(ptr, dest_ty);
17171924
}
17181925

1926+
if ptr.ty == dest_ty {
1927+
return ptr;
1928+
}
1929+
17191930
// Strip a previous `pointercast`, to reveal the original pointer type.
17201931
let ptr = ptr.strip_ptrcasts();
17211932

1933+
if ptr.ty == dest_ty {
1934+
return ptr;
1935+
}
1936+
17221937
let ptr_pointee = match self.lookup_type(ptr.ty) {
17231938
SpirvType::Pointer { pointee } => pointee,
17241939
other => self.fatal(format!(
@@ -1731,12 +1946,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
17311946
"pointercast called on non-pointer dest type: {other:?}"
17321947
)),
17331948
};
1734-
if ptr.ty == dest_ty {
1735-
ptr
1736-
} else if let Some((indices, _)) = self.recover_access_chain_from_offset(
1949+
let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);
1950+
1951+
if let Some((indices, _)) = self.recover_access_chain_from_offset(
17371952
ptr_pointee,
17381953
Size::ZERO,
1739-
self.lookup_type(dest_pointee).sizeof(self),
1954+
dest_pointee_size..=dest_pointee_size,
17401955
Some(dest_pointee),
17411956
) {
17421957
let indices = indices
@@ -2687,6 +2902,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
26872902
Store(ID, ID),
26882903
Load(ID, ID),
26892904
Call(ID, ID, SmallVec<[ID; 4]>),
2905+
2906+
// HACK(eddyb) this only exists for better error reporting,
2907+
// as `Result<Inst<...>, Op>` would only report one `Op`.
2908+
Unsupported(
2909+
// HACK(eddyb) only exists for `fmt::Debug` in case of error.
2910+
#[allow(dead_code)] Op,
2911+
),
26902912
}
26912913

26922914
let taken_inst_idx_range = Cell::new(func.blocks[block_idx].instructions.len())..;
@@ -2732,7 +2954,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
27322954
(Op::FunctionCall, Some(r), [f, args @ ..]) => {
27332955
Inst::Call(r, *f, args.iter().copied().collect())
27342956
}
2735-
_ => return None,
2957+
_ => Inst::Unsupported(inst.class.opcode),
27362958
},
27372959
)
27382960
});

0 commit comments

Comments
 (0)