Skip to content

Commit 8858d5f

Browse files
committed
builder: aggressively pointercast/bitcast to paper over opaque pointers.
1 parent acf8506 commit 8858d5f

File tree

13 files changed

+405
-268
lines changed

13 files changed

+405
-268
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

+335-212
Large diffs are not rendered by default.

crates/rustc_codegen_spirv/src/builder/mod.rs

+11-13
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
102102
self.current_span.unwrap_or(DUMMY_SP)
103103
}
104104

105+
// HACK(eddyb) like the `CodegenCx` method but with `self.span()` awareness.
106+
pub fn type_ptr_to(&self, ty: Word) -> Word {
107+
SpirvType::Pointer { pointee: ty }.def(self.span(), self)
108+
}
109+
105110
// Given an ID, check if it's defined by an OpAccessChain, and if it is, return its ptr/indices
106111
fn find_access_chain(&self, id: spirv::Word) -> Option<(spirv::Word, Vec<spirv::Word>)> {
107112
let emit = self.emit();
@@ -130,20 +135,16 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
130135
indices: &[SpirvValue],
131136
is_inbounds: bool,
132137
) -> SpirvValue {
138+
// HACK(eddyb) temporary workaround for untyped pointers upstream.
139+
// FIXME(eddyb) replace with untyped memory SPIR-V + `qptr` or similar.
140+
let ptr = self.pointercast(ptr, self.type_ptr_to(ty));
141+
133142
// The first index is an offset to the pointer, the rest are actual members.
134143
// https://llvm.org/docs/GetElementPtr.html
135144
// "An OpAccessChain instruction is the equivalent of an LLVM getelementptr instruction where the first index element is zero."
136145
// https://github.com/gpuweb/gpuweb/issues/33
137146
let mut result_indices = Vec::with_capacity(indices.len() - 1);
138-
let mut result_pointee_type = match self.lookup_type(ptr.ty) {
139-
SpirvType::Pointer { pointee } => {
140-
assert_ty_eq!(self, ty, pointee);
141-
pointee
142-
}
143-
other_type => self.fatal(format!(
144-
"GEP first deref not implemented for type {other_type:?}"
145-
)),
146-
};
147+
let mut result_pointee_type = ty;
147148
for index in indices.iter().cloned().skip(1) {
148149
result_indices.push(index.def(self));
149150
result_pointee_type = match self.lookup_type(result_pointee_type) {
@@ -154,10 +155,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
154155
)),
155156
};
156157
}
157-
let result_type = SpirvType::Pointer {
158-
pointee: result_pointee_type,
159-
}
160-
.def(self.span(), self);
158+
let result_type = self.type_ptr_to(result_pointee_type);
161159

162160
let ptr_id = ptr.def(self);
163161
if let Some((original_ptr, mut original_indices)) = self.find_access_chain(ptr_id) {

crates/rustc_codegen_spirv/src/builder_spirv.rs

+17-5
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ pub enum SpirvValueKind {
5858
/// Pointer value being cast.
5959
original_ptr: Word,
6060

61-
/// Pointee type of `original_ptr`.
62-
original_pointee_ty: Word,
61+
/// Pointer type of `original_ptr`.
62+
original_ptr_ty: Word,
6363

6464
/// Result ID for the `OpBitcast` instruction representing the cast,
6565
/// to attach zombies to.
@@ -77,6 +77,18 @@ pub struct SpirvValue {
7777
}
7878

7979
impl SpirvValue {
80+
pub fn strip_ptrcasts(self) -> Self {
81+
match self.kind {
82+
SpirvValueKind::LogicalPtrCast {
83+
original_ptr,
84+
original_ptr_ty,
85+
bitcast_result_id: _,
86+
} => original_ptr.with_type(original_ptr_ty),
87+
88+
_ => self,
89+
}
90+
}
91+
8092
pub fn const_fold_load(self, cx: &CodegenCx<'_>) -> Option<Self> {
8193
match self.kind {
8294
SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => {
@@ -173,17 +185,17 @@ impl SpirvValue {
173185

174186
SpirvValueKind::LogicalPtrCast {
175187
original_ptr: _,
176-
original_pointee_ty,
188+
original_ptr_ty,
177189
bitcast_result_id,
178190
} => {
179191
cx.zombie_with_span(
180192
bitcast_result_id,
181193
span,
182194
&format!(
183195
"cannot cast between pointer types\
184-
\nfrom `*{}`\
196+
\nfrom `{}`\
185197
\n to `{}`",
186-
cx.debug_type(original_pointee_ty),
198+
cx.debug_type(original_ptr_ty),
187199
cx.debug_type(self.ty)
188200
),
189201
);

crates/rustc_codegen_spirv/src/codegen_cx/constant.rs

+9-8
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,8 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
370370
self.builder.lookup_const_by_id(pointee)
371371
{
372372
if let SpirvType::Pointer { pointee } = self.lookup_type(ty) {
373-
let init = self.create_const_alloc(alloc, pointee);
373+
let mut offset = Size::ZERO;
374+
let init = self.read_from_const_alloc(alloc, &mut offset, pointee);
374375
return self.static_addr_of(init, alloc.inner().align, None);
375376
}
376377
}
@@ -422,7 +423,7 @@ impl<'tcx> CodegenCx<'tcx> {
422423
// alloc.len()
423424
// );
424425
let mut offset = Size::ZERO;
425-
let result = self.create_const_alloc2(alloc, &mut offset, ty);
426+
let result = self.read_from_const_alloc(alloc, &mut offset, ty);
426427
assert_eq!(
427428
offset.bytes_usize(),
428429
alloc.inner().len(),
@@ -432,7 +433,7 @@ impl<'tcx> CodegenCx<'tcx> {
432433
result
433434
}
434435

435-
fn create_const_alloc2(
436+
fn read_from_const_alloc(
436437
&self,
437438
alloc: ConstAllocation<'tcx>,
438439
offset: &mut Size,
@@ -515,7 +516,7 @@ impl<'tcx> CodegenCx<'tcx> {
515516
let total_offset_start = base + field_offset;
516517
let mut total_offset_end = total_offset_start;
517518
values.push(
518-
self.create_const_alloc2(alloc, &mut total_offset_end, ty)
519+
self.read_from_const_alloc(alloc, &mut total_offset_end, ty)
519520
.def_cx(self),
520521
);
521522
occupied_spaces.push(total_offset_start..total_offset_end);
@@ -534,7 +535,7 @@ impl<'tcx> CodegenCx<'tcx> {
534535
SpirvType::Array { element, count } => {
535536
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
536537
let values = (0..count).map(|_| {
537-
self.create_const_alloc2(alloc, offset, element)
538+
self.read_from_const_alloc(alloc, offset, element)
538539
.def_cx(self)
539540
});
540541
self.constant_composite(ty, values)
@@ -545,7 +546,7 @@ impl<'tcx> CodegenCx<'tcx> {
545546
.expect("create_const_alloc: Vectors must be sized");
546547
let final_offset = *offset + total_size;
547548
let values = (0..count).map(|_| {
548-
self.create_const_alloc2(alloc, offset, element)
549+
self.read_from_const_alloc(alloc, offset, element)
549550
.def_cx(self)
550551
});
551552
let result = self.constant_composite(ty, values);
@@ -560,7 +561,7 @@ impl<'tcx> CodegenCx<'tcx> {
560561
.expect("create_const_alloc: Matrices must be sized");
561562
let final_offset = *offset + total_size;
562563
let values = (0..count).map(|_| {
563-
self.create_const_alloc2(alloc, offset, element)
564+
self.read_from_const_alloc(alloc, offset, element)
564565
.def_cx(self)
565566
});
566567
let result = self.constant_composite(ty, values);
@@ -573,7 +574,7 @@ impl<'tcx> CodegenCx<'tcx> {
573574
let mut values = Vec::new();
574575
while offset.bytes_usize() != alloc.inner().len() {
575576
values.push(
576-
self.create_const_alloc2(alloc, offset, element)
577+
self.read_from_const_alloc(alloc, offset, element)
577578
.def_cx(self),
578579
);
579580
}

crates/rustc_codegen_spirv/src/codegen_cx/type_.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
194194
align,
195195
size,
196196
field_types: els,
197-
field_offsets: &field_offsets.as_slice(),
197+
field_offsets: &field_offsets,
198198
field_names: None,
199199
}
200200
.def(DUMMY_SP, self)

crates/rustc_codegen_spirv/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
//! [`spirv-tools`]: https://embarkstudios.github.io/rust-gpu/api/spirv_tools
1717
//! [`spirv-tools-sys`]: https://embarkstudios.github.io/rust-gpu/api/spirv_tools_sys
1818
#![feature(rustc_private)]
19+
#![feature(array_methods)]
1920
#![feature(assert_matches)]
2021
#![feature(result_flattening)]
2122
#![feature(lint_reasons)]

crates/spirv-std/macros/src/lib.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -740,10 +740,14 @@ impl SampleImplRewriter {
740740
fn add_regs(&self, t: &mut Vec<TokenTree>) {
741741
for i in 0..SAMPLE_PARAM_COUNT {
742742
if self.0 & (1 << i) != 0 {
743+
// HACK(eddyb) the extra `{...}` force the pointers to be to
744+
// fresh variables holding value copies, instead of the originals,
745+
// allowing `OpLoad _` inference to pick the appropriate type.
743746
let s = if is_grad(i) {
744-
String::from("grad_x=in(reg) &params.grad.0.0,grad_y=in(reg) &params.grad.0.1,")
747+
"grad_x=in(reg) &{params.grad.0.0},grad_y=in(reg) &{params.grad.0.1},"
748+
.to_string()
745749
} else {
746-
format!("{0} = in(reg) &params.{0}.0,", SAMPLE_PARAM_NAMES[i])
750+
format!("{0} = in(reg) &{{params.{0}.0}},", SAMPLE_PARAM_NAMES[i])
747751
};
748752
let ts: proc_macro2::TokenStream = s.parse().unwrap();
749753
t.extend(ts);

tests/ui/arch/debug_printf_type_checking.stderr

+17-19
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ help: the return type of this call is `u32` due to the type of the argument pass
7575
| |
7676
| this argument influences the return type of `spirv_std`
7777
note: function defined here
78-
--> $SPIRV_STD_SRC/lib.rs:135:8
78+
--> $SPIRV_STD_SRC/lib.rs:136:8
7979
|
80-
135 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
80+
136 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
8181
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
8282
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
8383
help: change the type of the numeric literal from `u32` to `f32`
@@ -102,9 +102,9 @@ help: the return type of this call is `f32` due to the type of the argument pass
102102
| |
103103
| this argument influences the return type of `spirv_std`
104104
note: function defined here
105-
--> $SPIRV_STD_SRC/lib.rs:135:8
105+
--> $SPIRV_STD_SRC/lib.rs:136:8
106106
|
107-
135 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
107+
136 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
108108
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
109109
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
110110
help: change the type of the numeric literal from `f32` to `u32`
@@ -113,32 +113,30 @@ help: change the type of the numeric literal from `f32` to `u32`
113113
| ~~~
114114

115115
error[E0277]: the trait bound `{float}: Vector<f32, 2>` is not satisfied
116-
--> $DIR/debug_printf_type_checking.rs:23:31
116+
--> $DIR/debug_printf_type_checking.rs:23:9
117117
|
118118
23 | debug_printf!("%v2f", 11.0);
119-
| ----------------------^^^^-
120-
| | |
121-
| | the trait `Vector<f32, 2>` is not implemented for `{float}`
122-
| required by a bound introduced by this call
119+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Vector<f32, 2>` is not implemented for `{float}`
123120
|
124121
= help: the following other types implement trait `Vector<T, N>`:
122+
<Vec2 as Vector<f32, 2>>
123+
<Vec3 as Vector<f32, 3>>
124+
<Vec3A as Vector<f32, 3>>
125+
<Vec4 as Vector<f32, 4>>
125126
<DVec2 as Vector<f64, 2>>
126127
<DVec3 as Vector<f64, 3>>
127128
<DVec4 as Vector<f64, 4>>
128129
<IVec2 as Vector<i32, 2>>
129-
<IVec3 as Vector<i32, 3>>
130-
<IVec4 as Vector<i32, 4>>
131-
<UVec2 as Vector<u32, 2>>
132-
<UVec3 as Vector<u32, 3>>
133130
and 5 others
134131
note: required by a bound in `debug_printf_assert_is_vector`
135-
--> $SPIRV_STD_SRC/lib.rs:142:8
132+
--> $SPIRV_STD_SRC/lib.rs:143:8
136133
|
137-
140 | pub fn debug_printf_assert_is_vector<
134+
141 | pub fn debug_printf_assert_is_vector<
138135
| ----------------------------- required by a bound in this function
139-
141 | TY: crate::scalar::Scalar,
140-
142 | V: crate::vector::Vector<TY, SIZE>,
136+
142 | TY: crate::scalar::Scalar,
137+
143 | V: crate::vector::Vector<TY, SIZE>,
141138
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector`
139+
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
142140

143141
error[E0308]: mismatched types
144142
--> $DIR/debug_printf_type_checking.rs:24:29
@@ -157,9 +155,9 @@ help: the return type of this call is `Vec2` due to the type of the argument pas
157155
| |
158156
| this argument influences the return type of `spirv_std`
159157
note: function defined here
160-
--> $SPIRV_STD_SRC/lib.rs:135:8
158+
--> $SPIRV_STD_SRC/lib.rs:136:8
161159
|
162-
135 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
160+
136 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
163161
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
164162
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
165163

tests/ui/dis/ptr_copy.normal.stderr

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
error: cannot memcpy dynamically sized data
2-
--> $CORE_SRC/intrinsics.rs:2767:9
2+
--> $CORE_SRC/intrinsics.rs:2771:9
33
|
4-
2767 | copy(src, dst, count)
4+
2771 | copy(src, dst, count)
55
| ^^^^^^^^^^^^^^^^^^^^^
66
|
77
note: used from within `core::intrinsics::copy::<f32>`
8-
--> $CORE_SRC/intrinsics.rs:2753:21
8+
--> $CORE_SRC/intrinsics.rs:2757:21
99
|
10-
2753 | pub const unsafe fn copy<T>(src: *const T, dst: *mut T, count: usize) {
10+
2757 | pub const unsafe fn copy<T>(src: *const T, dst: *mut T, count: usize) {
1111
| ^^^^
1212
note: called by `ptr_copy::copy_via_raw_ptr`
1313
--> $DIR/ptr_copy.rs:28:18

tests/ui/dis/ptr_read.stderr

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
%4 = OpFunctionParameter %5
33
%6 = OpFunctionParameter %5
44
%7 = OpLabel
5-
OpLine %8 1179 8
5+
OpLine %8 1180 8
66
%9 = OpLoad %10 %4
77
OpLine %11 7 13
88
OpStore %6 %9

tests/ui/dis/ptr_read_method.stderr

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
%4 = OpFunctionParameter %5
33
%6 = OpFunctionParameter %5
44
%7 = OpLabel
5-
OpLine %8 1179 8
5+
OpLine %8 1180 8
66
%9 = OpLoad %10 %4
77
OpLine %11 7 13
88
OpStore %6 %9

tests/ui/dis/ptr_write.stderr

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
%7 = OpLabel
55
OpLine %8 7 35
66
%9 = OpLoad %10 %4
7-
OpLine %11 1377 8
7+
OpLine %11 1379 8
88
OpStore %6 %9
99
OpNoLine
1010
OpReturn

tests/ui/dis/ptr_write_method.stderr

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
%7 = OpLabel
55
OpLine %8 7 37
66
%9 = OpLoad %10 %4
7-
OpLine %11 1377 8
7+
OpLine %11 1379 8
88
OpStore %6 %9
99
OpNoLine
1010
OpReturn

0 commit comments

Comments
 (0)