Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga] Builtin function database, for automatic conversions #6833

Draft
wants to merge 9 commits into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 298 additions & 0 deletions naga/src/common/wgsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,301 @@ impl StandardFilterableTriggeringRule {
}
}
}

impl crate::MathFunction {
pub fn to_wgsl(self) -> &'static str {
use crate::MathFunction as Mf;

match self {
Mf::Abs => "abs",
Mf::Min => "min",
Mf::Max => "max",
Mf::Clamp => "clamp",
Mf::Saturate => "saturate",
Mf::Cos => "cos",
Mf::Cosh => "cosh",
Mf::Sin => "sin",
Mf::Sinh => "sinh",
Mf::Tan => "tan",
Mf::Tanh => "tanh",
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
Mf::Asinh => "asinh",
Mf::Acosh => "acosh",
Mf::Atanh => "atanh",
Mf::Radians => "radians",
Mf::Degrees => "degrees",
Mf::Ceil => "ceil",
Mf::Floor => "floor",
Mf::Round => "round",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Ldexp => "ldexp",
Mf::Exp => "exp",
Mf::Exp2 => "exp2",
Mf::Log => "log",
Mf::Log2 => "log2",
Mf::Pow => "pow",
Mf::Dot => "dot",
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",
Mf::Normalize => "normalize",
Mf::FaceForward => "faceForward",
Mf::Reflect => "reflect",
Mf::Refract => "refract",
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothstep",
Mf::Sqrt => "sqrt",
Mf::InverseSqrt => "inverseSqrt",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => "quantizeToF16",
Mf::CountTrailingZeros => "countTrailingZeros",
Mf::CountLeadingZeros => "countLeadingZeros",
Mf::CountOneBits => "countOneBits",
Mf::ReverseBits => "reverseBits",
Mf::ExtractBits => "extractBits",
Mf::InsertBits => "insertBits",
Mf::FirstTrailingBit => "firstTrailingBit",
Mf::FirstLeadingBit => "firstLeadingBit",
Mf::Pack4x8snorm => "pack4x8snorm",
Mf::Pack4x8unorm => "pack4x8unorm",
Mf::Pack2x16snorm => "pack2x16snorm",
Mf::Pack2x16unorm => "pack2x16unorm",
Mf::Pack2x16float => "pack2x16float",
Mf::Pack4xI8 => "pack4xI8",
Mf::Pack4xU8 => "pack4xU8",
Mf::Unpack4x8snorm => "unpack4x8snorm",
Mf::Unpack4x8unorm => "unpack4x8unorm",
Mf::Unpack2x16snorm => "unpack2x16snorm",
Mf::Unpack2x16unorm => "unpack2x16unorm",
Mf::Unpack2x16float => "unpack2x16float",
Mf::Unpack4xI8 => "unpack4xI8",
Mf::Unpack4xU8 => "unpack4xU8",
Mf::Inverse => "{matrix inverse}",
Mf::Outer => "{vector outer product}",
}
}
}

impl crate::BuiltIn {
pub fn to_wgsl(self) -> &'static str {
match self {
crate::BuiltIn::Position { invariant: true } => "@position @invariant",
crate::BuiltIn::Position { invariant: false } => "@position",
crate::BuiltIn::ViewIndex => "view_index",
crate::BuiltIn::BaseInstance => "{BaseInstance}",
crate::BuiltIn::BaseVertex => "{BaseVertex}",
crate::BuiltIn::ClipDistance => "{ClipDistance}",
crate::BuiltIn::CullDistance => "{CullDistance}",
crate::BuiltIn::InstanceIndex => "instance_index",
crate::BuiltIn::PointSize => "{PointSize}",
crate::BuiltIn::VertexIndex => "vertex_index",
crate::BuiltIn::DrawID => "{DrawId}",
crate::BuiltIn::FragDepth => "frag_depth",
crate::BuiltIn::PointCoord => "{PointCoord}",
crate::BuiltIn::FrontFacing => "front_facing",
crate::BuiltIn::PrimitiveIndex => "primitive_index",
crate::BuiltIn::SampleIndex => "sample_index",
crate::BuiltIn::SampleMask => "sample_mask",
crate::BuiltIn::GlobalInvocationId => "global_invocation_id",
crate::BuiltIn::LocalInvocationId => "local_invocation_id",
crate::BuiltIn::LocalInvocationIndex => "local_invocation_index",
crate::BuiltIn::WorkGroupId => "workgroup_id",
crate::BuiltIn::WorkGroupSize => "{WorkGroupSize}",
crate::BuiltIn::NumWorkGroups => "num_workgroups",
crate::BuiltIn::NumSubgroups => "num_subgroups",
crate::BuiltIn::SubgroupId => "{SubgroupId}",
crate::BuiltIn::SubgroupSize => "subgroup_size",
crate::BuiltIn::SubgroupInvocationId => "subgroup_invocation_id",
}
}
}

impl crate::Interpolation {
pub fn to_wgsl(self) -> &'static str {
match self {
crate::Interpolation::Perspective => "perspective",
crate::Interpolation::Linear => "linear",
crate::Interpolation::Flat => "flat",
}
}
}

impl crate::Sampling {
pub fn to_wgsl(self) -> &'static str {
match self {
crate::Sampling::Center => "center",
crate::Sampling::Centroid => "centroid",
crate::Sampling::Sample => "sample",
crate::Sampling::First => "first",
crate::Sampling::Either => "either",
}
}
}

pub struct Wgslish<T>(pub T);

impl Display for Wgslish<&crate::TypeInner> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match *self.0 {
crate::TypeInner::Scalar(scalar) => Wgslish(scalar).fmt(f),
crate::TypeInner::Vector { size, scalar } => {
write!(f, "vec{}<{}>", size as u8, Wgslish(scalar))
}
crate::TypeInner::Matrix {
columns,
rows,
scalar,
} => {
write!(
f,
"mat{}x{}<{}>",
columns as u8,
rows as u8,
Wgslish(scalar)
)
}
crate::TypeInner::Atomic(scalar) => {
write!(f, "atomic<{}>", Wgslish(scalar))
}
crate::TypeInner::Pointer { base, space } => {
write!(f, "ptr<{}, {base:?}>", Wgslish(space))
}
crate::TypeInner::ValuePointer {
size,
scalar,
space,
} => {
write!(f, "ptr<{}, ", Wgslish(space))?;
match size {
Some(size) => write!(f, "vec{}<{}>", size as u8, Wgslish(scalar))?,
None => Wgslish(scalar).fmt(f)?,
}
f.write_str(">")
}
crate::TypeInner::Array { base, size, stride } => {
write!(f, "@stride({stride}) array<{base:?}")?;
match size {
crate::ArraySize::Constant(non_zero) => write!(f, ", {non_zero}>"),
crate::ArraySize::Pending(pending_array_size) => match pending_array_size {
crate::PendingArraySize::Expression(handle) => {
write!(f, "expression {handle:?}")
}
crate::PendingArraySize::Override(handle) => {
write!(f, "override {handle:?}")
}
},
crate::ArraySize::Dynamic => f.write_str(">"),
}
}
crate::TypeInner::Struct { ref members, span } => {
write!(f, "@span({span}) struct {{ ")?;
for (i, member) in members.iter().enumerate() {
if i != 0 {
f.write_str(", ")?;
}
write!(f, "@offset({}) ", member.offset)?;
if let Some(ref binding) = member.binding {
Wgslish(binding).fmt(f)?;
}
write!(
f,
"{}: {:?}",
member.name.as_deref().unwrap_or("<anonymous>"),
member.ty
)?;
}
f.write_str("}")
}
crate::TypeInner::Image {
dim: _,
arrayed: _,
class: _,
} => todo!(),
crate::TypeInner::Sampler { comparison: _ } => todo!(),
crate::TypeInner::AccelerationStructure => todo!(),
crate::TypeInner::RayQuery => todo!(),
crate::TypeInner::BindingArray { base, size } => {
write!(f, "array<{base:?}")?;
match size {
crate::ArraySize::Constant(non_zero) => write!(f, ", {non_zero}>"),
crate::ArraySize::Pending(pending_array_size) => match pending_array_size {
crate::PendingArraySize::Expression(handle) => {
write!(f, "expression {handle:?}")
}
crate::PendingArraySize::Override(handle) => {
write!(f, "override {handle:?}")
}
},
crate::ArraySize::Dynamic => f.write_str(">"),
}
}
}
}
}

impl Display for Wgslish<crate::Scalar> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let bits = self.0.width * 8;
match self.0.kind {
crate::ScalarKind::Sint => write!(f, "i{bits}"),
crate::ScalarKind::Uint => write!(f, "u{bits}"),
crate::ScalarKind::Float => write!(f, "f{bits}"),
crate::ScalarKind::Bool => f.write_str("bool"),
crate::ScalarKind::AbstractInt => f.write_str("{AbstractInt}"),
crate::ScalarKind::AbstractFloat => f.write_str("{AbstractFloat}"),
}
}
}

impl Display for Wgslish<crate::AddressSpace> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let wgsl = match self.0 {
crate::AddressSpace::Function => "function",
crate::AddressSpace::Private => "private",
crate::AddressSpace::WorkGroup => "workgroup",
crate::AddressSpace::Uniform => "uniform",
crate::AddressSpace::Storage { access } => {
return write!(f, "storage, {access:?}");
}
crate::AddressSpace::Handle => "handle",
crate::AddressSpace::PushConstant => "push_constant",
};
f.write_str(wgsl)
}
}

impl Display for Wgslish<&crate::Binding> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match *self.0 {
crate::Binding::BuiltIn(built_in) => f.write_str(built_in.to_wgsl()),
crate::Binding::Location {
location,
second_blend_source,
interpolation,
sampling,
} => {
write!(f, "@location({location})")?;
if second_blend_source {
f.write_str(" @second_blend_source")?;
}
if let Some(interpolation) = interpolation {
write!(f, " {}", interpolation.to_wgsl())?;
}
if let Some(sampling) = sampling {
write!(f, " {}", sampling.to_wgsl())?;
}
Ok(())
}
}
}
}
81 changes: 81 additions & 0 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,25 @@ pub(crate) enum Error<'a> {
expected: Range<u32>,
found: u32,
},
WrongArgumentCountForOverloads {
function: String,
call_span: Span,
arg_span: Span,
max_arguments: u32,
overloads: Vec<String>,
},
WrongArgumentType {
function: String,
call_span: Span,
arg_span: Span,
arg_index: u32,
found: String,
allowed: Vec<String>,
},
AmbiguousCall {
call_span: Span,
alternatives: Vec<String>,
},
FunctionReturnsVoid(Span),
FunctionMustUseUnused(Span),
FunctionMustUseReturnsVoid(Span, Span),
Expand Down Expand Up @@ -809,6 +828,68 @@ impl<'a> Error<'a> {
labels: vec![(span, "wrong number of arguments".into())],
notes: vec![],
},
Error::WrongArgumentCountForOverloads {
ref function,
call_span,
arg_span,
max_arguments,
ref overloads,
} => {
let message = format!(
"For the preceding argument types, `{function}` accepts only {max_arguments} arguments"
);
let labels = vec![
(call_span, "This function call has too many arguments".into()),
(arg_span, "This is the first excess argument".into())
];
let mut notes = vec![
format!("These are the only overloads of `{function}` that could accept the preceding arguments:"),
];
notes.extend(overloads.iter().map(|o| format!("overload: {o}")));

ParseError { message, labels, notes }
}
Error::WrongArgumentType {
ref function,
call_span,
arg_span,
arg_index,
ref found,
ref allowed,
} => {
let message = format!(
"This call to `{function}` cannot accept a value of type `{found}` for argument #{}",
arg_index + 1,
);
let labels = vec![
(call_span, "The arguments to this function call have incorrect types".into()),
(arg_span, format!(
"This argument has type `{found}`",
).into())
];

let mut notes = vec![];
if arg_index > 0 {
notes.push("Given the types of the preceding arguments,".into());
notes.push(format!("the following types are allowed for argument #{}:", arg_index + 1));
} else {
notes.push("The following types are allowed for the first argument:".to_string());
};
notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}")));

ParseError { message, labels, notes }
},
Error::AmbiguousCall { call_span, ref alternatives } => {
let message = "Function call is ambiguous: more than one overload could apply".into();
let labels = vec![
(call_span, "More than one overload of this function could apply to these arguments".into()),
];
let mut notes = vec![
"All of the following overloads could apply, but no one overload is clearly preferable:".into()
];
notes.extend(alternatives.iter().map(|alt| format!("possible overload: {alt}")));
ParseError { message, labels, notes }
},
Error::FunctionReturnsVoid(span) => ParseError {
message: "function does not return any value".to_string(),
labels: vec![(span, "".into())],
Expand Down
Loading
Loading