32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/FormatVariadic.h"
35#define DEBUG_TYPE "gpu-to-llvm-spv"
40#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
41#include "mlir/Conversion/Passes.h.inc"
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                              StringRef name,
                                              ArrayRef<Type> paramTypes,
                                              Type resultType, bool isMemNone,
                                              bool isConvergent) {
  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(symbolTable, name));
  if (!func) {
    OpBuilder b(symbolTable->getRegion(0));
    func = LLVM::LLVMFuncOp::create(
        b, symbolTable->getLoc(), name,
        LLVM::LLVMFunctionType::get(resultType, paramTypes));
    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
    func.setNoUnwind(true);
    func.setWillReturn(true);

    if (isMemNone) {
      // The builtin has no externally observable memory effects.
      constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
      auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/noModRef, /*argMem=*/noModRef,
          /*inaccessibleMem=*/noModRef);
      func.setMemoryEffectsAttr(memAttr);
    }

    func.setConvergent(isConvergent);
  }
  return func;
}
static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           LLVM::LLVMFuncOp func,
                                           ValueRange args) {
  // Copy the declaration's attributes so the call site agrees with the
  // callee on calling convention and effects.
  auto call = LLVM::CallOp::create(rewriter, loc, func, args);
  call.setCConv(func.getCConv());
  call.setConvergentAttr(func.getConvergentAttr());
  call.setNoUnwindAttr(func.getNoUnwindAttr());
  call.setWillReturnAttr(func.getWillReturnAttr());
  call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
  return call;
}
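
/// Replace `gpu.barrier` with an `llvm.call` to the OpenCL `barrier` builtin,
/// passing a flag that selects the memory fences to apply. A sketch of the
/// resulting IR, assuming only workgroup memory is fenced
/// (CLK_LOCAL_MEM_FENCE == 1):
///   %flag = llvm.mlir.constant(1 : i32) : i32
///   llvm.call spir_funccc @_Z7barrierj(%flag) : (i32) -> ()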
struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringLiteral funcName = "_Z7barrierj";

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type flagTy = rewriter.getI32Type();
    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
                              /*isMemNone=*/false, /*isConvergent=*/true);

    // OpenCL CLK_LOCAL_MEM_FENCE and CLK_GLOBAL_MEM_FENCE flag values.
    constexpr int64_t localMemFenceFlag = 1;
    constexpr int64_t globalMemFenceFlag = 2;
    int64_t memFenceFlag = 0;
    std::optional<ArrayAttr> addressSpaces = adaptor.getAddressSpaces();
    if (addressSpaces) {
      for (Attribute attr : addressSpaces.value()) {
        auto addressSpace = cast<gpu::AddressSpaceAttr>(attr).getValue();
        switch (addressSpace) {
        case gpu::AddressSpace::Global:
          memFenceFlag = memFenceFlag | globalMemFenceFlag;
          break;
        case gpu::AddressSpace::Workgroup:
          memFenceFlag = memFenceFlag | localMemFenceFlag;
          break;
        case gpu::AddressSpace::Private:
          break;
        }
      }
    } else {
      // No address spaces given: fence both local and global memory.
      memFenceFlag = localMemFenceFlag | globalMemFenceFlag;
    }
    Location loc = op->getLoc();
    Value flag = LLVM::ConstantOp::create(rewriter, loc, flagTy, memFenceFlag);
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
    return success();
  }
};
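
/// Base pattern for lowering launch-configuration queries to calls to their
/// OpenCL builtin counterparts, passing the queried dimension as an i32
/// argument:
///   gpu.block_id  -> get_group_id    gpu.grid_dim  -> get_num_groups
///   gpu.block_dim -> get_local_size  gpu.thread_id -> get_local_id
///   gpu.global_id -> get_global_id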
struct LaunchConfigConversion : ConvertToLLVMPattern {
  LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
                         MLIRContext *context,
                         const LLVMTypeConverter &typeConverter,
                         PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
        funcName(funcName) {}

  virtual gpu::Dimension getDimension(Operation *op) const = 0;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type dimTy = rewriter.getI32Type();
    Type indexTy = getTypeConverter()->getIndexType();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy, indexTy,
                              /*isMemNone=*/true, /*isConvergent=*/false);

    Location loc = op->getLoc();
    gpu::Dimension dim = getDimension(op);
    Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
                                            static_cast<int64_t>(dim));
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
    return success();
  }

  StringRef funcName;
};
template <typename SourceOp>
struct LaunchConfigOpConversion final : LaunchConfigConversion {
  static StringRef getFuncName();

  explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit = 1)
      : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
                               &typeConverter.getContext(), typeConverter,
                               benefit) {}

  gpu::Dimension getDimension(Operation *op) const final {
    return cast<SourceOp>(op).getDimension();
  }
};
template <>
StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
  return "_Z12get_group_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
  return "_Z14get_num_groupsj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
  return "_Z14get_local_sizej";
}

template <>
StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
  return "_Z12get_local_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
  return "_Z13get_global_idj";
}
struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  static StringRef getBaseName(gpu::ShuffleMode mode) {
    switch (mode) {
    case gpu::ShuffleMode::IDX:
      return "sub_group_shuffle";
    case gpu::ShuffleMode::XOR:
      return "sub_group_shuffle_xor";
    case gpu::ShuffleMode::UP:
      return "sub_group_shuffle_up";
    case gpu::ShuffleMode::DOWN:
      return "sub_group_shuffle_down";
    }
    llvm_unreachable("Unhandled shuffle mode");
  }

  /// Mangling code for the value parameter plus the trailing `j` for the
  /// unsigned offset parameter.
  static std::optional<StringRef> getTypeMangling(Type type) {
    return TypeSwitch<Type, std::optional<StringRef>>(type)
        .Case([](Float16Type) { return "Dhj"; })
        .Case([](Float32Type) { return "fj"; })
        .Case([](Float64Type) { return "dj"; })
        .Case([](IntegerType intTy) -> std::optional<StringRef> {
          switch (intTy.getWidth()) {
          case 8:
            return "cj";
          case 16:
            return "sj";
          case 32:
            return "ij";
          case 64:
            return "lj";
          default:
            return std::nullopt;
          }
        })
        .Default(std::nullopt);
  }

  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
                                                Type type) {
    StringRef baseName = getBaseName(mode);
    std::optional<StringRef> typeMangling = getTypeMangling(type);
    if (!typeMangling)
      return std::nullopt;
    return llvm::formatv("_Z{}{}{}", baseName.size(), baseName,
                         typeMangling.value());
  }

  /// Query the required subgroup size from the parent function's attribute,
  /// if present.
  static std::optional<int> getSubgroupSize(Operation *op) {
    auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!parentFunc)
      return std::nullopt;
    return parentFunc.getIntelReqdSubGroupSize();
  }

  static bool hasValidWidth(gpu::ShuffleOp op) {
    llvm::APInt val;
    Value width = op.getWidth();
    return matchPattern(width, m_ConstantInt(&val)) &&
           val == getSubgroupSize(op);
  }
  /// The builtins only accept a restricted set of types: widen `i1` to `i8`
  /// and bitcast `bf16` to `i16` before the call, and undo it afterwards.
  static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
                                         ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(oldVal.getType())
        .Case([&](BFloat16Type) {
          return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
                                         oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
                                        oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

  static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
                                          Location loc,
                                          ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(newTy)
        .Case([&](BFloat16Type) {
          return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!hasValidWidth(op))
      return rewriter.notifyMatchFailure(
          op, "shuffle width and subgroup size mismatch");

    Location loc = op->getLoc();
    Value inValue =
        bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
    std::optional<std::string> funcName =
        getFuncName(op.getMode(), inValue.getType());
    if (!funcName)
      return rewriter.notifyMatchFailure(op, "unsupported value type");

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type valueType = inValue.getType();
    Type offsetType = adaptor.getOffset().getType();
    Type resultType = valueType;
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        moduleOp, funcName.value(), {valueType, offsetType}, resultType,
        /*isMemNone=*/false, /*isConvergent=*/true);

    std::array<Value, 2> args{inValue, adaptor.getOffset()};
    Value result =
        createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
    Value resultOrConversion =
        bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);

    // The builtins do not return a validity bit; the width always matches
    // the subgroup size here, so the shuffle is trivially valid.
    Value trueVal =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), true);
    rewriter.replaceOp(op, {resultOrConversion, trueVal});
    return success();
  }
};
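
/// Type converter attaching an explicit address space to memref types that
/// carry none: bare memrefs are mapped to the OpenCL `global` (SPIR-V
/// CrossWorkgroup) address space, and function types are converted
/// elementwise so that signatures remain consistent.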
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
public:
  MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
    addConversion([](Type t) { return t; });
    addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
      // Only rewrite memrefs that do not already carry a memory space.
      Attribute memSpaceAttr = memRefType.getMemorySpace();
      if (memSpaceAttr)
        return std::nullopt;

      unsigned globalAddrspace = storageClassToAddressSpace(
          spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
      Attribute addrSpaceAttr =
          IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
      if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
        return MemRefType::get(memRefType.getShape(),
                               memRefType.getElementType(),
                               rankedType.getLayout(), addrSpaceAttr);
      }
      return UnrankedMemRefType::get(memRefType.getElementType(),
                                     addrSpaceAttr);
    });
    addConversion([this](FunctionType type) {
      auto inputs = llvm::map_to_vector(
          type.getInputs(), [this](Type ty) { return convertType(ty); });
      auto results = llvm::map_to_vector(
          type.getResults(), [this](Type ty) { return convertType(ty); });
      return FunctionType::get(type.getContext(), inputs, results);
    });
  }
};
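
/// Lower subgroup query ops to calls to their OpenCL builtin counterparts.
/// All of these builtins return an i32, which may need zero-extension to the
/// converted index type:
///   gpu.subgroup_id   -> get_sub_group_id
///   gpu.lane_id       -> get_sub_group_local_id
///   gpu.num_subgroups -> get_num_sub_groups
///   gpu.subgroup_size -> get_sub_group_size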
template <typename SubgroupOp>
struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
  using ConvertToLLVMPattern::getTypeConverter;

  LogicalResult
  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringRef funcName = [] {
      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
        return "_Z16get_sub_group_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
        return "_Z22get_sub_group_local_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
        return "_Z18get_num_sub_groups";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
        return "_Z18get_sub_group_size";
      }
    }();

    Operation *moduleOp =
        op->template getParentWithTrait<OpTrait::SymbolTable>();
    Type resultTy = rewriter.getI32Type();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
                              /*isMemNone=*/false, /*isConvergent=*/false);

    Location loc = op->getLoc();
    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();

    Type indexTy = getTypeConverter()->getIndexType();
    if (resultTy != indexTy) {
      // An i32 result cannot fit into a narrower index type.
      if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth())
        return failure();
      result = LLVM::ZExtOp::create(rewriter, loc, indexTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
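
/// Pass rewriting memref memory spaces to OpenCL address spaces and then
/// applying a partial conversion of the remaining GPU ops with the patterns
/// defined above.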
struct GPUToLLVMSPVConversionPass final
    : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
  void runOnOperation() final {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    LowerToLLVMOptions options(context);
    options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
    LLVMTypeConverter converter(context, options);
    LLVMConversionTarget target(*context);

    // Attach an OpenCL address space to memrefs without one before running
    // the conversion.
    {
      MemorySpaceToOpenCLMemorySpaceConverter converter(context);
      AttrTypeReplacer replacer;
      replacer.addReplacement([&converter](BaseMemRefType origType)
                                  -> std::optional<BaseMemRefType> {
        return converter.convertType<BaseMemRefType>(origType);
      });
      replacer.recursivelyReplaceElementsIn(getOperation(),
                                            /*replaceAttrs=*/true,
                                            /*replaceLocs=*/false,
                                            /*replaceTypes=*/true);
    }

    target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
                        gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                        gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                        gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
                        gpu::ThreadIdOp, gpu::PrintfOp>();

    populateGpuToLLVMSPVConversionPatterns(converter, patterns);
    populateGpuMemorySpaceAttributeConversions(converter);
    // printf lives in the OpenCL constant address space (the AS2 in the
    // mangled name).
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, 2,
                                                LLVM::cconv::CConv::SPIR_FUNC,
                                                "_Z6printfPU3AS2Kcz");

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
static unsigned
gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
  constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
  return storageClassToAddressSpace(clientAPI,
                                    addressSpaceToStorageClass(addressSpace));
}

void mlir::populateGpuToLLVMSPVConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<GPUBarrierConversion, GPUReturnOpLowering,
               GPUShuffleConversion,
               GPUSubgroupOpConversion<gpu::LaneIdOp>,
               GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
               GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
               GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
               LaunchConfigOpConversion<gpu::BlockDimOp>,
               LaunchConfigOpConversion<gpu::BlockIdOp>,
               LaunchConfigOpConversion<gpu::GlobalIdOp>,
               LaunchConfigOpConversion<gpu::GridDimOp>,
               LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);

  MLIRContext *context = &typeConverter.getContext();
  unsigned privateAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
  unsigned localAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
  OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
  StringAttr kernelBlockSizeAttributeName =
      LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
  patterns.add<GPUFuncOpLowering>(
      typeConverter,
      GPUFuncOpLoweringOptions{
          privateAddressSpace, localAddressSpace,
          {}, kernelBlockSizeAttributeName,
          {}, LLVM::CConv::SPIR_KERNEL,
          LLVM::CConv::SPIR_FUNC,
          /*encodeWorkgroupAttributionsAsArguments=*/true});
}

void mlir::populateGpuMemorySpaceAttributeConversions(
    TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(typeConverter,
                                             gpuAddressSpaceToOCLAddressSpace);
}