32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/FormatVariadic.h"
35#define DEBUG_TYPE "gpu-to-llvm-spv"
40#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
41#include "mlir/Conversion/Passes.h.inc"
51 Type resultType,
bool isMemNone,
53 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
57 func = LLVM::LLVMFuncOp::create(
59 LLVM::LLVMFunctionType::get(resultType, paramTypes));
60 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
61 func.setNoUnwind(
true);
62 func.setWillReturn(
true);
66 constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
67 auto memAttr =
b.getAttr<LLVM::MemoryEffectsAttr>(
73 func.setMemoryEffectsAttr(memAttr);
76 func.setConvergent(isConvergent);
82 ConversionPatternRewriter &rewriter,
83 LLVM::LLVMFuncOp
func,
85 auto call = LLVM::CallOp::create(rewriter, loc,
func, args);
86 call.setCConv(
func.getCConv());
87 call.setConvergentAttr(
func.getConvergentAttr());
88 call.setNoUnwindAttr(
func.getNoUnwindAttr());
89 call.setWillReturnAttr(
func.getWillReturnAttr());
90 call.setMemoryEffectsAttr(
func.getMemoryEffectsAttr());
110 matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter)
const final {
112 constexpr StringLiteral funcName =
"_Z7barrierj";
114 Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
115 assert(moduleOp &&
"Expecting module");
116 Type flagTy = rewriter.getI32Type();
117 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
118 LLVM::LLVMFuncOp func =
124 constexpr int64_t localMemFenceFlag = 1;
125 Location loc = op->getLoc();
127 LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag);
146 LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
147 MLIRContext *context,
148 const LLVMTypeConverter &typeConverter,
149 PatternBenefit benefit)
150 : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
151 funcName(funcName) {}
153 virtual gpu::Dimension getDimension(Operation *op)
const = 0;
156 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
157 ConversionPatternRewriter &rewriter)
const final {
159 assert(moduleOp &&
"Expecting module");
160 Type dimTy = rewriter.getI32Type();
161 Type indexTy = getTypeConverter()->getIndexType();
166 Location loc = op->getLoc();
167 gpu::Dimension dim = getDimension(op);
168 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
169 static_cast<int64_t
>(dim));
177template <
typename SourceOp>
178struct LaunchConfigOpConversion final : LaunchConfigConversion {
179 static StringRef getFuncName();
181 explicit LaunchConfigOpConversion(
const LLVMTypeConverter &typeConverter,
182 PatternBenefit benefit = 1)
183 : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
187 gpu::Dimension getDimension(Operation *op)
const final {
188 return cast<SourceOp>(op).getDimension();
193StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
194 return "_Z12get_group_idj";
198StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
199 return "_Z14get_num_groupsj";
203StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
204 return "_Z14get_local_sizej";
208StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
209 return "_Z12get_local_idj";
213StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
214 return "_Z13get_global_idj";
233 static StringRef getBaseName(gpu::ShuffleMode mode) {
235 case gpu::ShuffleMode::IDX:
236 return "sub_group_shuffle";
237 case gpu::ShuffleMode::XOR:
238 return "sub_group_shuffle_xor";
239 case gpu::ShuffleMode::UP:
240 return "sub_group_shuffle_up";
241 case gpu::ShuffleMode::DOWN:
242 return "sub_group_shuffle_down";
244 llvm_unreachable(
"Unhandled shuffle mode");
247 static std::optional<StringRef> getTypeMangling(Type type) {
249 .Case<Float16Type>([](
auto) {
return "Dhj"; })
250 .Case<Float32Type>([](
auto) {
return "fj"; })
251 .Case<Float64Type>([](
auto) {
return "dj"; })
252 .Case<IntegerType>([](
auto intTy) -> std::optional<StringRef> {
253 switch (intTy.getWidth()) {
265 .Default(std::nullopt);
268 static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
270 StringRef baseName = getBaseName(mode);
271 std::optional<StringRef> typeMangling = getTypeMangling(type);
274 return llvm::formatv(
"_Z{}{}{}", baseName.size(), baseName,
275 typeMangling.value());
279 static std::optional<int> getSubgroupSize(Operation *op) {
283 return parentFunc.getIntelReqdSubGroupSize();
286 static bool hasValidWidth(gpu::ShuffleOp op) {
288 Value width = op.getWidth();
290 val == getSubgroupSize(op);
293 static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
294 ConversionPatternRewriter &rewriter) {
296 .Case([&](BFloat16Type) {
297 return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
300 .Case([&](IntegerType intTy) -> Value {
301 if (intTy.getWidth() == 1)
302 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
309 static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
311 ConversionPatternRewriter &rewriter) {
313 .Case([&](BFloat16Type) {
314 return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
316 .Case([&](IntegerType intTy) -> Value {
317 if (intTy.getWidth() == 1)
318 return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
325 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter)
const final {
327 if (!hasValidWidth(op))
328 return rewriter.notifyMatchFailure(
329 op,
"shuffle width and subgroup size mismatch");
331 Location loc = op->getLoc();
333 bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
334 std::optional<std::string> funcName =
335 getFuncName(op.getMode(), inValue.
getType());
337 return rewriter.notifyMatchFailure(op,
"unsupported value type");
340 assert(moduleOp &&
"Expecting module");
341 Type valueType = inValue.
getType();
342 Type offsetType = adaptor.getOffset().getType();
343 Type resultType = valueType;
345 moduleOp, funcName.value(), {valueType, offsetType}, resultType,
348 std::array<Value, 2> args{inValue, adaptor.getOffset()};
351 Value resultOrConversion =
352 bitcastOrTruncAfterShuffle(
result, op.getType(0), loc, rewriter);
355 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(),
true);
356 rewriter.replaceOp(op, {resultOrConversion, trueVal});
361class MemorySpaceToOpenCLMemorySpaceConverter final :
public TypeConverter {
363 MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
364 addConversion([](Type t) {
return t; });
365 addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
372 spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
373 Attribute addrSpaceAttr =
374 IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
375 if (
auto rankedType = dyn_cast<MemRefType>(memRefType)) {
376 return MemRefType::get(memRefType.
getShape(),
378 rankedType.getLayout(), addrSpaceAttr);
383 addConversion([
this](FunctionType type) {
384 auto inputs = llvm::map_to_vector(
385 type.getInputs(), [
this](Type ty) { return convertType(ty); });
386 auto results = llvm::map_to_vector(
387 type.getResults(), [
this](Type ty) { return convertType(ty); });
388 return FunctionType::get(type.getContext(), inputs, results);
397template <
typename SubgroupOp>
399 using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
403 matchAndRewrite(SubgroupOp op,
typename SubgroupOp::Adaptor adaptor,
404 ConversionPatternRewriter &rewriter)
const final {
405 constexpr StringRef funcName = [] {
406 if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
407 return "_Z16get_sub_group_id";
408 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
409 return "_Z22get_sub_group_local_id";
410 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
411 return "_Z18get_num_sub_groups";
412 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
413 return "_Z18get_sub_group_size";
417 Operation *moduleOp =
418 op->template getParentWithTrait<OpTrait::SymbolTable>();
419 Type resultTy = rewriter.getI32Type();
420 LLVM::LLVMFuncOp func =
424 Location loc = op->getLoc();
427 Type indexTy = getTypeConverter()->getIndexType();
428 if (resultTy != indexTy) {
432 result = LLVM::ZExtOp::create(rewriter, loc, indexTy,
result);
435 rewriter.replaceOp(op,
result);
444struct GPUToLLVMSPVConversionPass final
448 void runOnOperation() final {
450 RewritePatternSet
patterns(context);
452 LowerToLLVMOptions
options(context);
453 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
454 LLVMTypeConverter converter(context,
options);
455 LLVMConversionTarget
target(*context);
459 MemorySpaceToOpenCLMemorySpaceConverter converter(context);
460 AttrTypeReplacer replacer;
462 -> std::optional<BaseMemRefType> {
463 return converter.convertType<BaseMemRefType>(origType);
472 target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
473 gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
474 gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
475 gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
476 gpu::ThreadIdOp, gpu::PrintfOp>();
480 patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, 2,
481 LLVM::cconv::CConv::SPIR_FUNC,
482 "_Z6printfPU3AS2Kcz");
484 if (
failed(applyPartialConversion(getOperation(),
target,
498gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
499 constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
508 GPUSubgroupOpConversion<gpu::LaneIdOp>,
509 GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
510 GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
511 GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
512 LaunchConfigOpConversion<gpu::BlockDimOp>,
513 LaunchConfigOpConversion<gpu::BlockIdOp>,
514 LaunchConfigOpConversion<gpu::GlobalIdOp>,
515 LaunchConfigOpConversion<gpu::GridDimOp>,
516 LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
518 unsigned privateAddressSpace =
519 gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
520 unsigned localAddressSpace =
521 gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
522 OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
523 StringAttr kernelBlockSizeAttributeName =
524 LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
528 privateAddressSpace, localAddressSpace,
529 {}, kernelBlockSizeAttributeName,
530 LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
536 gpuAddressSpaceToOCLAddressSpace);
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static llvm::ManagedStatic< PassManagerOptions > options
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for operation conversions targeting the LLVM IR dialect.
const LLVMTypeConverter * getTypeConverter() const
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type getType() const
Return the type of this value.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateGpuToLLVMSPVConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
spirv::StorageClass addressSpaceToStorageClass(gpu::AddressSpace addressSpace)
const FrozenRewritePatternSet & patterns
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
llvm::TypeSwitch< T, ResultT > TypeSwitch