32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/FormatVariadic.h"
35#define DEBUG_TYPE "gpu-to-llvm-spv"
40#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
41#include "mlir/Conversion/Passes.h.inc"
// NOTE(review): extraction artifact — original source line numbers are fused
// into the text and interior lines are missing. Code kept byte-identical.
// Looks up (or declares) an OpenCL SPIR-V builtin wrapper function with the
// given result/parameter types. New declarations are marked nounwind,
// willreturn, SPIR_FUNC calling convention, optionally NoModRef memory
// effects, and optionally convergent.
51 Type resultType,
bool isMemNone,
// Reuse an existing declaration if the symbol is already present.
53 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
57 func = LLVM::LLVMFuncOp::create(
59 LLVM::LLVMFunctionType::get(resultType, paramTypes));
// OpenCL builtins use the SPIR function calling convention.
60 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
61 func.setNoUnwind(
true);
62 func.setWillReturn(
true);
// isMemNone builtins neither read nor write memory: NoModRef effects.
66 constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
67 auto memAttr =
b.getAttr<LLVM::MemoryEffectsAttr>(
73 func.setMemoryEffectsAttr(memAttr);
// e.g. subgroup shuffles must be convergent; launch queries need not be.
76 func.setConvergent(isConvergent);
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Emits a call to a SPIR-V builtin declaration and mirrors the callee's
// calling convention and convergent / nounwind / willreturn / memory-effects
// attributes onto the call site so call and declaration stay consistent.
82 ConversionPatternRewriter &rewriter,
83 LLVM::LLVMFuncOp
func,
85 auto call = LLVM::CallOp::create(rewriter, loc,
func, args);
// Copy all relevant function attributes onto the call operation.
86 call.setCConv(
func.getCConv());
87 call.setConvergentAttr(
func.getConvergentAttr());
88 call.setNoUnwindAttr(
func.getNoUnwindAttr());
89 call.setWillReturnAttr(
func.getWillReturnAttr());
90 call.setMemoryEffectsAttr(
func.getMemoryEffectsAttr());
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Lowers gpu.barrier to the OpenCL `barrier(cl_mem_fence_flags)` builtin,
// Itanium-mangled as _Z7barrierj, fencing both local and global memory.
112 matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
113 ConversionPatternRewriter &rewriter)
const final {
114 constexpr StringLiteral funcName =
"_Z7barrierj";
// The builtin declaration is inserted into the closest symbol-table
// ancestor (expected to be the module).
116 Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
117 assert(moduleOp &&
"Expecting module");
// barrier() takes an i32 flags argument and returns void.
118 Type flagTy = rewriter.getI32Type();
119 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
120 LLVM::LLVMFuncOp func =
// OpenCL CLK_LOCAL_MEM_FENCE (1) | CLK_GLOBAL_MEM_FENCE (2).
127 constexpr int64_t localMemFenceFlag = 1;
128 constexpr int64_t globalMemFenceFlag = 2;
129 constexpr int64_t localGlobalMemFenceFlag =
130 localMemFenceFlag | globalMemFenceFlag;
131 Location loc = op->getLoc();
132 Value flag = LLVM::ConstantOp::create(rewriter, loc, flagTy,
133 localGlobalMemFenceFlag);
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Base pattern for lowering GPU launch-configuration queries (thread/block
// ids, dims, …) to dimension-indexed OpenCL builtins (funcName holds the
// mangled builtin name supplied by the concrete subclass).
152 LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
153 MLIRContext *context,
154 const LLVMTypeConverter &typeConverter,
155 PatternBenefit benefit)
156 : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
157 funcName(funcName) {}
// Subclasses report which gpu::Dimension (x/y/z) the matched op queries.
159 virtual gpu::Dimension getDimension(Operation *op)
const = 0;
162 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
163 ConversionPatternRewriter &rewriter)
const final {
165 assert(moduleOp &&
"Expecting module");
// The builtins take an i32 dimension index; the result is converted to
// the lowered index type.
166 Type dimTy = rewriter.getI32Type();
167 Type indexTy = getTypeConverter()->getIndexType();
172 Location loc = op->getLoc();
173 gpu::Dimension dim = getDimension(op);
// Materialize the dimension enum as the i32 builtin argument.
174 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
175 static_cast<int64_t
>(dim));
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Concrete launch-config pattern for a single GPU op; getFuncName() is
// specialized per SourceOp below with the mangled OpenCL builtin name.
183template <
typename SourceOp>
184struct LaunchConfigOpConversion final : LaunchConfigConversion {
185 static StringRef getFuncName();
187 explicit LaunchConfigOpConversion(
const LLVMTypeConverter &typeConverter,
188 PatternBenefit benefit = 1)
189 : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
// The queried dimension comes straight off the source op's attribute.
193 gpu::Dimension getDimension(Operation *op)
const final {
194 return cast<SourceOp>(op).getDimension();
// NOTE(review): fragment — `template <>` headers and closing braces are
// missing from this view; code kept byte-identical.
// Itanium-mangled OpenCL builtin names per launch-config query; each takes
// one unsigned dimension index (hence the trailing "j").
199StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
200 return "_Z12get_group_idj";
204StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
205 return "_Z14get_num_groupsj";
209StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
210 return "_Z14get_local_sizej";
214StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
215 return "_Z12get_local_idj";
219StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
220 return "_Z13get_global_idj";
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Lowers gpu.shuffle to the OpenCL sub_group_shuffle* builtins. The builtin
// base name is selected from the shuffle mode.
239 static StringRef getBaseName(gpu::ShuffleMode mode) {
241 case gpu::ShuffleMode::IDX:
242 return "sub_group_shuffle";
243 case gpu::ShuffleMode::XOR:
244 return "sub_group_shuffle_xor";
245 case gpu::ShuffleMode::UP:
246 return "sub_group_shuffle_up";
247 case gpu::ShuffleMode::DOWN:
248 return "sub_group_shuffle_down";
250 llvm_unreachable(
"Unhandled shuffle mode");
// Itanium mangling suffix for (value, offset) parameter types; trailing
// "j" is the unsigned offset. Returns std::nullopt for unsupported types.
253 static std::optional<StringRef> getTypeMangling(Type type) {
255 .Case<Float16Type>([](
auto) {
return "Dhj"; })
256 .Case<Float32Type>([](
auto) {
return "fj"; })
257 .Case<Float64Type>([](
auto) {
return "dj"; })
258 .Case<IntegerType>([](
auto intTy) -> std::optional<StringRef> {
259 switch (intTy.getWidth()) {
271 .Default(std::nullopt);
// Assembles the full mangled builtin name "_Z<len><base><mangling>".
274 static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
276 StringRef baseName = getBaseName(mode);
277 std::optional<StringRef> typeMangling = getTypeMangling(type);
280 return llvm::formatv(
"_Z{}{}{}", baseName.size(), baseName,
281 typeMangling.value());
// Reads the enclosing function's Intel required-subgroup-size attribute;
// presumably nullopt when the attribute is absent — TODO confirm (lines
// missing here).
285 static std::optional<int> getSubgroupSize(Operation *op) {
289 return parentFunc.getIntelReqdSubGroupSize();
// The shuffle width is only valid when it is a constant equal to the
// known subgroup size.
292 static bool hasValidWidth(gpu::ShuffleOp op) {
294 Value width = op.getWidth();
296 val == getSubgroupSize(op);
// Converts values the builtins cannot take directly: bf16 is bitcast to
// i16, i1 is zero-extended to i8; other types appear to pass through.
299 static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
300 ConversionPatternRewriter &rewriter) {
302 .Case([&](BFloat16Type) {
303 return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
306 .Case([&](IntegerType intTy) -> Value {
307 if (intTy.getWidth() == 1)
308 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
// Inverse of the pre-shuffle conversion: bitcast back to bf16, truncate
// back to i1.
315 static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
317 ConversionPatternRewriter &rewriter) {
319 .Case([&](BFloat16Type) {
320 return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
322 .Case([&](IntegerType intTy) -> Value {
323 if (intTy.getWidth() == 1)
324 return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
// Main lowering: validate width, convert the value, declare the builtin,
// call it, convert the result back, and return (result, valid=true).
331 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
332 ConversionPatternRewriter &rewriter)
const final {
333 if (!hasValidWidth(op))
334 return rewriter.notifyMatchFailure(
335 op,
"shuffle width and subgroup size mismatch");
337 Location loc = op->getLoc();
339 bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
340 std::optional<std::string> funcName =
341 getFuncName(op.getMode(), inValue.
getType());
343 return rewriter.notifyMatchFailure(op,
"unsupported value type");
346 assert(moduleOp &&
"Expecting module");
347 Type valueType = inValue.
getType();
348 Type offsetType = adaptor.getOffset().getType();
349 Type resultType = valueType;
351 moduleOp, funcName.value(), {valueType, offsetType}, resultType,
354 std::array<Value, 2> args{inValue, adaptor.getOffset()};
357 Value resultOrConversion =
358 bitcastOrTruncAfterShuffle(
result, op.getType(0), loc, rewriter);
// gpu.shuffle's second result (the valid flag) is always true here,
// since the width was checked against the subgroup size above.
361 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(),
true);
362 rewriter.replaceOp(op, {resultOrConversion, trueVal});
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Type converter that rewrites memref memory spaces to OpenCL address
// spaces (numeric i64 attributes), and maps function types through the
// same conversion element-wise.
367class MemorySpaceToOpenCLMemorySpaceConverter final :
public TypeConverter {
369 MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
// Fallback: leave any unhandled type untouched.
370 addConversion([](Type t) {
return t; });
// Memrefs: presumably only those without an explicit memory space get
// the CrossWorkgroup (global) address space — TODO confirm, the guard
// lines are missing from this view.
371 addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
378 spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
379 Attribute addrSpaceAttr =
380 IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
381 if (
auto rankedType = dyn_cast<MemRefType>(memRefType)) {
382 return MemRefType::get(memRefType.
getShape(),
384 rankedType.getLayout(), addrSpaceAttr);
// Function types: convert inputs and results recursively.
389 addConversion([
this](FunctionType type) {
390 auto inputs = llvm::map_to_vector(
391 type.getInputs(), [
this](Type ty) { return convertType(ty); });
392 auto results = llvm::map_to_vector(
393 type.getResults(), [
this](Type ty) { return convertType(ty); });
394 return FunctionType::get(type.getContext(), inputs, results);
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Lowers the nullary subgroup query ops (subgroup_id, lane_id,
// num_subgroups, subgroup_size) to their mangled OpenCL builtins, selected
// at compile time from the pattern's source-op type.
403template <
typename SubgroupOp>
405 using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
409 matchAndRewrite(SubgroupOp op,
typename SubgroupOp::Adaptor adaptor,
410 ConversionPatternRewriter &rewriter)
const final {
// Compile-time dispatch to the right builtin name; no runtime branch.
411 constexpr StringRef funcName = [] {
412 if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
413 return "_Z16get_sub_group_id";
414 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
415 return "_Z22get_sub_group_local_id";
416 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
417 return "_Z18get_num_sub_groups";
418 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
419 return "_Z18get_sub_group_size";
423 Operation *moduleOp =
424 op->template getParentWithTrait<OpTrait::SymbolTable>();
// These builtins return i32; widen to the lowered index type if needed.
425 Type resultTy = rewriter.getI32Type();
426 LLVM::LLVMFuncOp func =
430 Location loc = op->getLoc();
433 Type indexTy = getTypeConverter()->getIndexType();
434 if (resultTy != indexTy) {
438 result = LLVM::ZExtOp::create(rewriter, loc, indexTy,
result);
441 rewriter.replaceOp(op,
result);
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Pass driver: sets up the LLVM type converter (32- or 64-bit index per
// option), rewrites GPU memory spaces to OpenCL ones, marks the GPU ops
// below illegal, and runs partial conversion.
450struct GPUToLLVMSPVConversionPass final
454 void runOnOperation() final {
456 RewritePatternSet
patterns(context);
458 LowerToLLVMOptions
options(context);
// Index bitwidth is user-controlled via the use64bitIndex pass option.
459 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
460 LLVMTypeConverter converter(context,
options);
461 LLVMConversionTarget
target(*context);
// Pre-pass: rewrite memref memory-space attributes in place before the
// dialect conversion runs. (Shadows the outer `converter` — presumably
// in a nested scope; the enclosing braces are missing from this view.)
465 MemorySpaceToOpenCLMemorySpaceConverter converter(context);
466 AttrTypeReplacer replacer;
468 -> std::optional<BaseMemRefType> {
469 return converter.convertType<BaseMemRefType>(origType);
// Every GPU op with a lowering here must be converted away.
478 target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
479 gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
480 gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
481 gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
482 gpu::ThreadIdOp, gpu::PrintfOp>();
// printf lowers to the OpenCL builtin with a constant (AS2) format
// string; benefit 2 so it wins over generic printf lowerings.
486 patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, 2,
487 LLVM::cconv::CConv::SPIR_FUNC,
488 "_Z6printfPU3AS2Kcz");
490 if (
failed(applyPartialConversion(getOperation(),
target,
// NOTE(review): fragment — interior lines missing; code kept byte-identical.
// Maps gpu.address_space values to OpenCL numeric address spaces for the
// OpenCL client API.
504gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
505 constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
// Registers every op-lowering pattern defined in this file.
514 GPUSubgroupOpConversion<gpu::LaneIdOp>,
515 GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
516 GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
517 GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
518 LaunchConfigOpConversion<gpu::BlockDimOp>,
519 LaunchConfigOpConversion<gpu::BlockIdOp>,
520 LaunchConfigOpConversion<gpu::GlobalIdOp>,
521 LaunchConfigOpConversion<gpu::GridDimOp>,
522 LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
// gpu.func lowering needs the OpenCL private/local address spaces, the
// reqd_work_group_size attribute name, and SPIR kernel/function cconvs.
524 unsigned privateAddressSpace =
525 gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
526 unsigned localAddressSpace =
527 gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
528 OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
529 StringAttr kernelBlockSizeAttributeName =
530 LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
534 privateAddressSpace, localAddressSpace,
535 {}, kernelBlockSizeAttributeName,
536 LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
// Also install the gpu.address_space → OpenCL address-space mapping.
542 gpuAddressSpaceToOCLAddressSpace);
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static llvm::ManagedStatic< PassManagerOptions > options
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for operation conversions targeting the LLVM IR dialect.
const LLVMTypeConverter * getTypeConverter() const
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type getType() const
Return the type of this value.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateGpuToLLVMSPVConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
spirv::StorageClass addressSpaceToStorageClass(gpu::AddressSpace addressSpace)
const FrozenRewritePatternSet & patterns
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
llvm::TypeSwitch< T, ResultT > TypeSwitch