32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/FormatVariadic.h"
35#define DEBUG_TYPE "gpu-to-llvm-spv"
40#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
41#include "mlir/Conversion/Passes.h.inc"
51 Type resultType,
bool isMemNone,
53 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
57 func = LLVM::LLVMFuncOp::create(
59 LLVM::LLVMFunctionType::get(resultType, paramTypes));
60 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
61 func.setNoUnwind(
true);
62 func.setWillReturn(
true);
66 constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
67 auto memAttr =
b.getAttr<LLVM::MemoryEffectsAttr>(
73 func.setMemoryEffectsAttr(memAttr);
76 func.setConvergent(isConvergent);
82 ConversionPatternRewriter &rewriter,
83 LLVM::LLVMFuncOp
func,
85 auto call = LLVM::CallOp::create(rewriter, loc,
func, args);
86 call.setCConv(
func.getCConv());
87 call.setConvergentAttr(
func.getConvergentAttr());
88 call.setNoUnwindAttr(
func.getNoUnwindAttr());
89 call.setWillReturnAttr(
func.getWillReturnAttr());
90 call.setMemoryEffectsAttr(
func.getMemoryEffectsAttr());
112 matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
113 ConversionPatternRewriter &rewriter)
const final {
114 constexpr StringLiteral funcName =
"_Z7barrierj";
116 Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
117 assert(moduleOp &&
"Expecting module");
118 Type flagTy = rewriter.getI32Type();
119 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
120 LLVM::LLVMFuncOp func =
126 constexpr int64_t localMemFenceFlag = 1;
127 constexpr int64_t globalMemFenceFlag = 2;
128 int64_t memFenceFlag = 0;
129 std::optional<ArrayAttr> addressSpaces = adaptor.getAddressSpaces();
131 for (Attribute attr : addressSpaces.value()) {
132 auto addressSpace = cast<gpu::AddressSpaceAttr>(attr).getValue();
133 switch (addressSpace) {
134 case gpu::AddressSpace::Global:
135 memFenceFlag = memFenceFlag | globalMemFenceFlag;
137 case gpu::AddressSpace::Workgroup:
138 memFenceFlag = memFenceFlag | localMemFenceFlag;
140 case gpu::AddressSpace::Private:
141 case gpu::AddressSpace::Constant:
147 memFenceFlag = localMemFenceFlag | globalMemFenceFlag;
149 Location loc = op->getLoc();
150 Value flag = LLVM::ConstantOp::create(rewriter, loc, flagTy, memFenceFlag);
169 LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
170 MLIRContext *context,
171 const LLVMTypeConverter &typeConverter,
172 PatternBenefit benefit)
173 : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
174 funcName(funcName) {}
176 virtual gpu::Dimension getDimension(Operation *op)
const = 0;
179 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
180 ConversionPatternRewriter &rewriter)
const final {
182 assert(moduleOp &&
"Expecting module");
183 Type dimTy = rewriter.getI32Type();
184 Type indexTy = getTypeConverter()->getIndexType();
189 Location loc = op->getLoc();
190 gpu::Dimension dim = getDimension(op);
191 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
192 static_cast<int64_t
>(dim));
200template <
typename SourceOp>
201struct LaunchConfigOpConversion final : LaunchConfigConversion {
202 static StringRef getFuncName();
204 explicit LaunchConfigOpConversion(
const LLVMTypeConverter &typeConverter,
205 PatternBenefit benefit = 1)
206 : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
210 gpu::Dimension getDimension(Operation *op)
const final {
211 return cast<SourceOp>(op).getDimension();
216StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
217 return "_Z12get_group_idj";
221StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
222 return "_Z14get_num_groupsj";
226StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
227 return "_Z14get_local_sizej";
231StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
232 return "_Z12get_local_idj";
236StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
237 return "_Z13get_global_idj";
256 static StringRef getBaseName(gpu::ShuffleMode mode) {
258 case gpu::ShuffleMode::IDX:
259 return "sub_group_shuffle";
260 case gpu::ShuffleMode::XOR:
261 return "sub_group_shuffle_xor";
262 case gpu::ShuffleMode::UP:
263 return "sub_group_shuffle_up";
264 case gpu::ShuffleMode::DOWN:
265 return "sub_group_shuffle_down";
267 llvm_unreachable(
"Unhandled shuffle mode");
270 static std::optional<StringRef> getTypeMangling(Type type) {
272 .Case([](Float16Type) {
return "Dhj"; })
273 .Case([](Float32Type) {
return "fj"; })
274 .Case([](Float64Type) {
return "dj"; })
275 .Case([](IntegerType intTy) -> std::optional<StringRef> {
276 switch (intTy.getWidth()) {
288 .Default(std::nullopt);
291 static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
293 StringRef baseName = getBaseName(mode);
294 std::optional<StringRef> typeMangling = getTypeMangling(type);
297 return llvm::formatv(
"_Z{}{}{}", baseName.size(), baseName,
298 typeMangling.value());
302 static std::optional<int> getSubgroupSize(Operation *op) {
306 return parentFunc.getIntelReqdSubGroupSize();
309 static bool hasValidWidth(gpu::ShuffleOp op,
int subgroupSize) {
311 Value width = op.getWidth();
315 static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
316 ConversionPatternRewriter &rewriter) {
318 .Case([&](BFloat16Type) {
319 return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
322 .Case([&](IntegerType intTy) -> Value {
323 if (intTy.getWidth() == 1)
324 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
331 static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
333 ConversionPatternRewriter &rewriter) {
335 .Case([&](BFloat16Type) {
336 return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
338 .Case([&](IntegerType intTy) -> Value {
339 if (intTy.getWidth() == 1)
340 return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
347 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
348 ConversionPatternRewriter &rewriter)
const final {
349 auto maybeSubgroupSize = getSubgroupSize(op);
350 if (maybeSubgroupSize && !hasValidWidth(op, maybeSubgroupSize.value()))
351 return rewriter.notifyMatchFailure(
352 op,
"shuffle width and subgroup size mismatch");
354 Location loc = op->getLoc();
356 bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
357 std::optional<std::string> funcName =
358 getFuncName(op.getMode(), inValue.
getType());
360 return rewriter.notifyMatchFailure(op,
"unsupported value type");
363 assert(moduleOp &&
"Expecting module");
364 Type valueType = inValue.
getType();
365 Type offsetType = adaptor.getOffset().getType();
366 Type resultType = valueType;
368 moduleOp, funcName.value(), {valueType, offsetType}, resultType,
371 std::array<Value, 2> args{inValue, adaptor.getOffset()};
374 Value resultOrConversion =
375 bitcastOrTruncAfterShuffle(
result, op.getType(0), loc, rewriter);
378 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(),
true);
379 rewriter.replaceOp(op, {resultOrConversion, trueVal});
384class MemorySpaceToOpenCLMemorySpaceConverter final :
public TypeConverter {
386 MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
387 addConversion([](Type t) {
return t; });
388 addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
395 spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
396 Attribute addrSpaceAttr =
397 IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
398 if (
auto rankedType = dyn_cast<MemRefType>(memRefType)) {
399 return MemRefType::get(memRefType.
getShape(),
401 rankedType.getLayout(), addrSpaceAttr);
406 addConversion([
this](FunctionType type) {
407 auto inputs = llvm::map_to_vector(
408 type.getInputs(), [
this](Type ty) { return convertType(ty); });
409 auto results = llvm::map_to_vector(
410 type.getResults(), [
this](Type ty) { return convertType(ty); });
411 return FunctionType::get(type.getContext(), inputs, results);
420template <
typename SubgroupOp>
422 using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
426 matchAndRewrite(SubgroupOp op,
typename SubgroupOp::Adaptor adaptor,
427 ConversionPatternRewriter &rewriter)
const final {
428 constexpr StringRef funcName = [] {
429 if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
430 return "_Z16get_sub_group_id";
431 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
432 return "_Z22get_sub_group_local_id";
433 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
434 return "_Z18get_num_sub_groups";
435 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
436 return "_Z18get_sub_group_size";
440 Operation *moduleOp =
441 op->template getParentWithTrait<OpTrait::SymbolTable>();
442 Type resultTy = rewriter.getI32Type();
443 LLVM::LLVMFuncOp func =
447 Location loc = op->getLoc();
450 Type indexTy = getTypeConverter()->getIndexType();
451 if (resultTy != indexTy) {
455 result = LLVM::ZExtOp::create(rewriter, loc, indexTy,
result);
458 rewriter.replaceOp(op,
result);
467struct GPUToLLVMSPVConversionPass final
468 : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
471 void runOnOperation() final {
473 RewritePatternSet patterns(context);
475 LowerToLLVMOptions
options(context);
476 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
477 LLVMTypeConverter converter(context,
options);
478 LLVMConversionTarget
target(*context);
482 MemorySpaceToOpenCLMemorySpaceConverter converter(context);
483 AttrTypeReplacer replacer;
485 -> std::optional<BaseMemRefType> {
486 return converter.convertType<BaseMemRefType>(origType);
495 target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
496 gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
497 gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
498 gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
499 gpu::ThreadIdOp, gpu::PrintfOp>();
503 patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, 2,
504 LLVM::cconv::CConv::SPIR_FUNC,
505 "_Z6printfPU3AS2Kcz");
507 if (
failed(applyPartialConversion(getOperation(),
target,
508 std::move(patterns))))
521gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
522 constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
531 GPUSubgroupOpConversion<gpu::LaneIdOp>,
532 GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
533 GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
534 GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
535 LaunchConfigOpConversion<gpu::BlockDimOp>,
536 LaunchConfigOpConversion<gpu::BlockIdOp>,
537 LaunchConfigOpConversion<gpu::GlobalIdOp>,
538 LaunchConfigOpConversion<gpu::GridDimOp>,
539 LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
541 unsigned privateAddressSpace =
542 gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
543 unsigned localAddressSpace =
544 gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
545 OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
546 StringAttr kernelBlockSizeAttributeName =
547 LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
551 privateAddressSpace, localAddressSpace,
552 {}, kernelBlockSizeAttributeName,
553 {}, LLVM::CConv::SPIR_KERNEL,
554 LLVM::CConv::SPIR_FUNC,
560 gpuAddressSpaceToOCLAddressSpace);
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static llvm::ManagedStatic< PassManagerOptions > options
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for operation conversions targeting the LLVM IR dialect.
const LLVMTypeConverter * getTypeConverter() const
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type getType() const
Return the type of this value.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateGpuToLLVMSPVConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
spirv::StorageClass addressSpaceToStorageClass(gpu::AddressSpace addressSpace)
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
llvm::TypeSwitch< T, ResultT > TypeSwitch