Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions crates/rustc_codegen_spirv/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use std::{env, fs, mem};
/// `cargo publish`. We need to figure out a way to do this properly, but let's hardcode it for now :/
//const REQUIRED_RUST_TOOLCHAIN: &str = include_str!("../../rust-toolchain.toml");
const REQUIRED_RUST_TOOLCHAIN: &str = r#"[toolchain]
channel = "nightly-2025-06-30"
channel = "nightly-2025-08-04"
components = ["rust-src", "rustc-dev", "llvm-tools"]
# commit_hash = 35f6036521777bdc0dcea1f980be4c192962a168"#;
# commit_hash = f34ba774c78ea32b7c40598b8ad23e75cdac42a6"#;

fn rustc_output(arg: &str) -> Result<String, Box<dyn Error>> {
let rustc = env::var("RUSTC").unwrap_or_else(|_| "rustc".into());
Expand Down Expand Up @@ -237,6 +237,14 @@ pub(super) fn elf_e_flags(architecture: Architecture, sess: &Session) -> u32 {",
src = src.replace("alloca(field.size,", "typed_alloca(llfield_ty,");
}

// HACK(fee1-dead): our backend type number doesn't always match the type of the value. Should fix?
if relative_path == Path::new("src/mir/rvalue.rs") {
src = src.replace(
"debug_assert_eq!(bx.cx().val_ty(imm), from_backend_ty);",
"",
);
}

fs::write(out_path, src)?;
}
}
Expand Down
6 changes: 3 additions & 3 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
span = cx.tcx.def_span(adt.did());
}

let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs_unchecked(adt.did()));
let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_all_attrs(adt.did()));

if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
&& let Ok(spirv_type) =
Expand Down Expand Up @@ -791,7 +791,7 @@ fn trans_intrinsic_type<'tcx>(
let sampled_type = match args.type_at(0).kind() {
TyKind::Int(int) => match int {
IntTy::Isize => {
SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, true)
SpirvType::Integer(cx.tcx.data_layout.pointer_size().bits() as u32, true)
.def(span, cx)
}
IntTy::I8 => SpirvType::Integer(8, true).def(span, cx),
Expand All @@ -802,7 +802,7 @@ fn trans_intrinsic_type<'tcx>(
},
TyKind::Uint(uint) => match uint {
UintTy::Usize => {
SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, false)
SpirvType::Integer(cx.tcx.data_layout.pointer_size().bits() as u32, false)
.def(span, cx)
}
UintTy::U8 => SpirvType::Integer(8, false).def(span, cx),
Expand Down
33 changes: 26 additions & 7 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1825,10 +1825,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
self.declare_func_local_var(self.type_array(self.type_i8(), size.bytes()), align)
}

fn dynamic_alloca(&mut self, _len: Self::Value, _align: Align) -> Self::Value {
self.fatal("dynamic alloca not supported yet")
}

fn load(&mut self, ty: Self::Type, ptr: Self::Value, _align: Align) -> Self::Value {
let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, ty);
let loaded_val = ptr.const_fold_load(self).unwrap_or_else(|| {
Expand Down Expand Up @@ -3253,6 +3249,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// ignore
}

#[tracing::instrument(
level = "debug",
skip(self, callee_ty, _fn_attrs, fn_abi, callee, args, funclet)
)]
fn call(
&mut self,
callee_ty: Self::Type,
Expand All @@ -3263,9 +3263,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
funclet: Option<&Self::Funclet>,
instance: Option<ty::Instance<'tcx>>,
) -> Self::Value {
let span = tracing::span!(tracing::Level::DEBUG, "call");
let _enter = span.enter();

if funclet.is_some() {
self.fatal("TODO: Funclets are not supported");
}
Expand Down Expand Up @@ -3387,6 +3384,15 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
}

// HACK(fee1-dead): `MaybeUninit` uses a union which we don't have very good support yet. Replacing all calls to it
// with an `Undef` serves the same purpose and fixes compiler errors
if instance_def_id.is_some_and(|did| {
self.tcx
.is_diagnostic_item(rustc_span::sym::maybe_uninit_uninit, did)
}) {
return self.undef(result_type);
}

// Default: emit a regular function call
let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
self.emit()
Expand All @@ -3395,6 +3401,19 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.with_type(result_type)
}

fn tail_call(
&mut self,
_llty: Self::Type,
_fn_attrs: Option<&CodegenFnAttrs>,
_fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
_llfn: Self::Value,
_args: &[Self::Value],
_funclet: Option<&Self::Funclet>,
_instance: Option<ty::Instance<'tcx>>,
) {
todo!()
}

fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
self.intcast(val, dest_ty, false)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/builder_spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ impl<'tcx> BuilderSpirv<'tcx> {
FileName::Real(name) => {
name.to_string_lossy(FileNameDisplayPreference::Remapped)
}
_ => sf.name.prefer_remapped_unconditionaly().to_string().into(),
_ => sf.name.prefer_remapped_unconditionally().to_string().into(),
};
let file_name = {
// FIXME(eddyb) it should be possible to arena-allocate a
Expand Down
26 changes: 19 additions & 7 deletions crates/rustc_codegen_spirv/src/codegen_cx/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl ConstCodegenMethods for CodegenCx<'_> {
self.const_uint_big(ty, i)
}
fn const_usize(&self, i: u64) -> Self::Value {
let ptr_size = self.tcx.data_layout.pointer_size.bits() as u32;
let ptr_size = self.tcx.data_layout.pointer_size().bits() as u32;
let t = SpirvType::Integer(ptr_size, false).def(DUMMY_SP, self);
self.constant_int(t, i.into())
}
Expand Down Expand Up @@ -246,7 +246,7 @@ impl ConstCodegenMethods for CodegenCx<'_> {
}
}
Scalar::Ptr(ptr, _) => {
let (prov, offset) = ptr.into_parts();
let (prov, offset) = ptr.prov_and_relative_offset();
let alloc_id = prov.alloc_id();
let (base_addr, _base_addr_space) = match self.tcx.global_alloc(alloc_id) {
GlobalAlloc::Memory(alloc) => {
Expand All @@ -263,7 +263,7 @@ impl ConstCodegenMethods for CodegenCx<'_> {
.try_read_from_const_alloc(alloc, pointee)
.unwrap_or_else(|| self.const_data_from_alloc(alloc));
let value = self.static_addr_of(init, alloc.inner().align, None);
(value, AddressSpace::DATA)
(value, AddressSpace::ZERO)
}
GlobalAlloc::Function { instance } => (
self.get_fn_addr(instance),
Expand Down Expand Up @@ -292,12 +292,24 @@ impl ConstCodegenMethods for CodegenCx<'_> {
.try_read_from_const_alloc(alloc, pointee)
.unwrap_or_else(|| self.const_data_from_alloc(alloc));
let value = self.static_addr_of(init, alloc.inner().align, None);
(value, AddressSpace::DATA)
(value, AddressSpace::ZERO)
}
GlobalAlloc::Static(def_id) => {
assert!(self.tcx.is_static(def_id));
assert!(!self.tcx.is_thread_local_static(def_id));
(self.get_static(def_id), AddressSpace::DATA)
(self.get_static(def_id), AddressSpace::ZERO)
}
GlobalAlloc::TypeId { .. } => {
return if offset.bytes() == 0 {
self.constant_null(ty)
} else {
let result = self.undef(ty);
self.zombie_no_span(
result.def_cx(self),
"pointer has non-null integer address",
);
result
};
}
};
self.const_bitcast(self.const_ptr_byte_offset(base_addr, offset), ty)
Expand Down Expand Up @@ -430,7 +442,7 @@ impl<'tcx> CodegenCx<'tcx> {
.fatal(format!("invalid size for float: {other}"));
}
}),
SpirvType::Pointer { .. } => Primitive::Pointer(AddressSpace::DATA),
SpirvType::Pointer { .. } => Primitive::Pointer(AddressSpace::ZERO),
_ => unreachable!(),
};

Expand All @@ -449,7 +461,7 @@ impl<'tcx> CodegenCx<'tcx> {
.inner()
.read_scalar(self, range, /* read_provenance */ true)
{
let (prov, _offset) = ptr.into_parts();
let (prov, _offset) = ptr.prov_and_relative_offset();
primitive = Primitive::Pointer(
self.tcx.global_alloc(prov.alloc_id()).address_space(self),
);
Expand Down
6 changes: 3 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use crate::spirv_type::SpirvType;
use itertools::Itertools;
use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word};
use rustc_abi::Align;
use rustc_attr_data_structures::InlineAttr;
use rustc_codegen_ssa::traits::{PreDefineCodegenMethods, StaticCodegenMethods};
use rustc_hir::attrs::InlineAttr;
use rustc_middle::bug;
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility};
Expand Down Expand Up @@ -133,7 +133,7 @@ impl<'tcx> CodegenCx<'tcx> {
self.set_linkage(fn_id, symbol_name.to_owned(), linkage);
}

let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.get_attrs_unchecked(def_id));
let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.get_all_attrs(def_id));
if let Some(entry) = attrs.entry.map(|attr| attr.value) {
// HACK(eddyb) early insert to let `shader_entry_stub` call this
// very function via `get_fn_addr`.
Expand Down Expand Up @@ -167,7 +167,7 @@ impl<'tcx> CodegenCx<'tcx> {
}

// Check if this is a From trait implementation
if let Some(impl_def_id) = self.tcx.impl_of_method(def_id)
if let Some(impl_def_id) = self.tcx.impl_of_assoc(def_id)
&& let Some(trait_ref) = self.tcx.impl_trait_ref(impl_def_id)
{
let trait_def_id = trait_ref.skip_binder().def_id;
Expand Down
4 changes: 2 additions & 2 deletions crates/rustc_codegen_spirv/src/codegen_cx/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl<'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'tcx> {

impl<'tcx> CodegenCx<'tcx> {
pub fn type_usize(&self) -> Word {
let ptr_size = self.tcx.data_layout.pointer_size.bits() as u32;
let ptr_size = self.tcx.data_layout.pointer_size().bits() as u32;
SpirvType::Integer(ptr_size, false).def(DUMMY_SP, self)
}
}
Expand All @@ -146,7 +146,7 @@ impl BaseTypeCodegenMethods for CodegenCx<'_> {
SpirvType::Integer(128, false).def(DUMMY_SP, self)
}
fn type_isize(&self) -> Self::Type {
let ptr_size = self.tcx.data_layout.pointer_size.bits() as u32;
let ptr_size = self.tcx.data_layout.pointer_size().bits() as u32;
SpirvType::Integer(ptr_size, false).def(DUMMY_SP, self)
}

Expand Down
55 changes: 14 additions & 41 deletions crates/rustc_codegen_spirv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ mod target_feature;

use builder::Builder;
use codegen_cx::CodegenCx;
use maybe_pqp_cg_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
use maybe_pqp_cg_ssa::back::lto::{SerializedModule, ThinModule};
use maybe_pqp_cg_ssa::back::write::{
CodegenContext, FatLtoInput, ModuleConfig, OngoingCodegen, TargetMachineFactoryConfig,
};
Expand Down Expand Up @@ -170,7 +170,7 @@ use std::any::Any;
use std::fs;
use std::io::Cursor;
use std::io::Write;
use std::path::Path;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tracing::{error, warn};

Expand Down Expand Up @@ -332,43 +332,33 @@ impl WriteBackendMethods for SpirvCodegenBackend {
type ThinData = ();
type ThinBuffer = SpirvModuleBuffer;

// FIXME(eddyb) reuse the "merge" stage of `crate::linker` for this, or even
// delegate to `run_fat_lto` (although `-Zcombine-cgu` is much more niche).
fn run_link(
cgcx: &CodegenContext<Self>,
diag_handler: DiagCtxtHandle<'_>,
_modules: Vec<ModuleCodegen<Self::Module>>,
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
assert!(
cgcx.opts.unstable_opts.combine_cgu,
"`run_link` (for `WorkItemResult::NeedsLink`) should \
only be invoked due to `-Zcombine-cgu`"
);
diag_handler.fatal("Rust-GPU does not support `-Zcombine-cgu`")
}

// FIXME(eddyb) reuse the "merge" stage of `crate::linker` for this, or even
// consider setting `requires_lto = true` in the target specs and moving the
// entirety of `crate::linker` into this stage (lacking diagnostics may be
// an issue - it's surprising `CodegenBackend::link` has `Session` at all).
fn run_fat_lto(
fn run_and_optimize_fat_lto(
cgcx: &CodegenContext<Self>,
_exported_symbols_for_lto: &[String],
_each_linked_rlib_for_lto: &[PathBuf],
_modules: Vec<FatLtoInput<Self>>,
_cached_modules: Vec<(SerializedModule<Self::ModuleBuffer>, WorkProduct)>,
) -> Result<LtoModuleCodegen<Self>, FatalError> {
_diff_fncs: Vec<AutoDiffItem>,
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
assert!(
cgcx.lto == rustc_session::config::Lto::Fat,
"`run_fat_lto` (for `WorkItemResult::NeedsFatLto`) should \
"`run_and_optimize_fat_lto` (for `WorkItemResult::NeedsFatLto`) should \
only be invoked due to `-Clto` (or equivalent)"
);
unreachable!("Rust-GPU does not support fat LTO")
}

fn run_thin_lto(
cgcx: &CodegenContext<Self>,
// FIXME(bjorn3): Limit LTO exports to these symbols
_exported_symbols_for_lto: &[String],
_each_linked_rlib_for_lto: &[PathBuf], // njn: ?
modules: Vec<(String, Self::ThinBuffer)>,
cached_modules: Vec<(SerializedModule<Self::ModuleBuffer>, WorkProduct)>,
) -> Result<(Vec<LtoModuleCodegen<Self>>, Vec<WorkProduct>), FatalError> {
) -> Result<(Vec<ThinModule<Self>>, Vec<WorkProduct>), FatalError> {
link::run_thin(cgcx, modules, cached_modules)
}

Expand Down Expand Up @@ -409,16 +399,8 @@ impl WriteBackendMethods for SpirvCodegenBackend {
Ok(module)
}

fn optimize_fat(
cgcx: &CodegenContext<Self>,
module: &mut ModuleCodegen<Self::Module>,
) -> Result<(), FatalError> {
Self::optimize_common(cgcx, module)
}

fn codegen(
cgcx: &CodegenContext<Self>,
_diag_handler: DiagCtxtHandle<'_>,
module: ModuleCodegen<Self::Module>,
_config: &ModuleConfig,
) -> Result<CompiledModule, FatalError> {
Expand Down Expand Up @@ -457,15 +439,6 @@ impl WriteBackendMethods for SpirvCodegenBackend {
SpirvModuleBuffer(module.module_llvm.assemble()),
)
}

fn autodiff(
_cgcx: &CodegenContext<Self>,
_module: &ModuleCodegen<Self::Module>,
_diff_fncs: Vec<AutoDiffItem>,
_config: &ModuleConfig,
) -> Result<(), FatalError> {
unreachable!("Rust-GPU does not support autodiff")
}
}

impl ExtraBackendMethods for SpirvCodegenBackend {
Expand Down Expand Up @@ -510,6 +483,7 @@ impl ExtraBackendMethods for SpirvCodegenBackend {

// ... and now that we have everything pre-defined, fill out those definitions.
for &(mono_item, mono_item_data) in &mono_items {
tracing::trace!(?mono_item, "defining");
mono_item.define::<Builder<'_, '_>>(cx, cgu_name.as_str(), mono_item_data);
}

Expand Down Expand Up @@ -547,8 +521,7 @@ impl ExtraBackendMethods for SpirvCodegenBackend {
_sess: &Session,
_opt_level: config::OptLevel,
_target_features: &[String],
) -> Arc<(dyn Fn(TargetMachineFactoryConfig) -> Result<(), String> + Send + Sync + 'static)>
{
) -> Arc<dyn Fn(TargetMachineFactoryConfig) -> Result<(), String> + Send + Sync + 'static> {
Arc::new(|_| Ok(()))
}
}
Expand Down
Loading
Loading