@@ -470,7 +470,9 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
470470 Name = getStaticDeclName (*this , D);
471471
472472 mlir::Type LTy = getTypes ().convertTypeForMem (Ty);
473- cir::AddressSpaceAttr AS =
473+
474+ // The address space determined by __attribute__((addrspace(n))).
475+ cir::AddressSpaceAttr actualAS =
474476 builder.getAddrSpaceAttr (getGlobalVarAddressSpace (&D));
475477
476478 // OpenCL variables in local address space and CUDA shared
@@ -482,8 +484,9 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
482484 !D.hasAttr <CUDASharedAttr>())
483485 Init = builder.getZeroInitAttr (convertType (Ty));
484486
485- cir::GlobalOp GV = builder.createVersionedGlobal (
486- getModule (), getLoc (D.getLocation ()), Name, LTy, false , Linkage, AS);
487+ cir::GlobalOp GV =
488+ builder.createVersionedGlobal (getModule (), getLoc (D.getLocation ()), Name,
489+ LTy, false , Linkage, actualAS);
487490 // TODO(cir): infer visibility from linkage in global op builder.
488491 GV.setVisibility (getMLIRVisibilityFromCIRLinkage (Linkage));
489492 GV.setInitialValueAttr (Init);
@@ -497,14 +500,15 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
497500
498501 setGVProperties (GV, &D);
499502
500- // OG checks if the expected address space, denoted by the type, is the
501- // same as the actual address space indicated by attributes. If they aren't
502- // the same, an addrspacecast is emitted when this variable is accessed.
503- // In CIR however, cir.get_global alreadys carries that information in
504- // !cir.ptr type - if this global is in OpenCL local address space, then its
505- // type would be !cir.ptr<..., addrspace(offload_local)>. Therefore we don't
506- // need an explicit address space cast in CIR: they will get emitted when
507- // lowering to LLVM IR.
503+ // OG checks whether the expected address space (AS), denoted by
504+ // __attribute__((addrspace(n))), is the same as the actual AS indicated by
505+ // other attributes (such as __device__ in CUDA). If they aren't the same, an
506+ // addrspacecast is emitted when this variable is accessed, which means we
507+ // need it in this function. In CIR however, since we access globals by
508+ // `cir.get_global`, we won't emit a cast for GlobalOp here. Instead, we
509+ // record the AST, and create a CastOp in
510+ // `CIRGenBaseBuilder::createGetGlobal`.
511+ GV.setAstAttr (cir::ASTVarDeclAttr::get (&getMLIRContext (), &D));
508512
509513 // Ensure that the static local gets initialized by making sure the parent
510514 // function gets emitted eventually.
@@ -617,7 +621,10 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
617621 // TODO(cir): we should have a way to represent global ops as values without
618622 // having to emit a get global op. Sometimes these emissions are not used.
619623 auto addr = getBuilder ().createGetGlobal (globalOp);
620- auto getAddrOp = mlir::cast<cir::GetGlobalOp>(addr.getDefiningOp ());
624+ auto definingOp = addr.getDefiningOp ();
625+ bool hasCast = isa<cir::CastOp>(definingOp);
626+ auto getAddrOp = mlir::cast<cir::GetGlobalOp>(
627+ hasCast ? definingOp->getOperand (0 ).getDefiningOp () : definingOp);
621628
622629 CharUnits alignment = getContext ().getDeclAlign (&D);
623630
@@ -633,7 +640,7 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
633640 llvm_unreachable (" VLAs are NYI" );
634641
635642 // Save the type in case adding the initializer forces a type change.
636- auto expectedType = addr.getType ();
643+ auto expectedType = cast<cir::PointerType>( addr.getType () );
637644
638645 auto var = globalOp;
639646
@@ -678,7 +685,25 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
678685 //
679686 // FIXME: It is really dangerous to store this in the map; if anyone
680687 // RAUW's the GV uses of this constant will be invalid.
681- auto castedAddr = builder.createBitcast (getAddrOp.getAddr (), expectedType);
688+ mlir::Value castedAddr;
689+ if (!hasCast)
690+ castedAddr = builder.createBitcast (getAddrOp.getAddr (), expectedType);
691+ else {
692+ // If there is an extra CastOp from createGetGlobal, we need to remove the
693+ // existing addrspacecast, then supply a bitcast and a new addrspacecast:
694+ // %1 = cir.get_global @addr
695+ // %2 = cir.cast(addrspacecast, %1) <--- remove
696+ // %2 = cir.cast(bitcast, %1) <--- insert
697+ // %3 = cir.cast(addrspacecast, %2) <--- insert
698+ definingOp->erase ();
699+
700+ auto expectedTypeWithAS = cir::PointerType::get (
701+ expectedType.getPointee (), getAddrOp.getType ().getAddrSpace ());
702+ auto converted =
703+ builder.createBitcast (getAddrOp.getAddr (), expectedTypeWithAS);
704+ castedAddr = builder.createAddrSpaceCast (converted, expectedType);
705+ }
706+
682707 LocalDeclMap.find (&D)->second = Address (castedAddr, elemTy, alignment);
683708 CGM.setStaticLocalDeclAddress (&D, var);
684709
0 commit comments