32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/FormatVariadic.h"
35#define DEBUG_TYPE "gpu-to-llvm-spv"
40#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
41#include "mlir/Conversion/Passes.h.inc"
// NOTE(review): fragment of lookupOrCreateSPIRVFn — extraction dropped lines;
// the embedded numbers are original-source line numbers, not code. It looks up
// an existing LLVM::LLVMFuncOp declaration by name and creates one if absent.
51 Type resultType,
bool isMemNone,
53 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
// Not found: declare the builtin with the requested signature.
57 func = LLVM::LLVMFuncOp::create(
59 LLVM::LLVMFunctionType::get(resultType, paramTypes));
// SPIR-V builtins use the SPIR_FUNC calling convention; they are marked
// nounwind and willreturn so LLVM can optimize around the calls.
60 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
61 func.setNoUnwind(
true);
62 func.setWillReturn(
true);
// When isMemNone, attach a memory-effects attribute built from NoModRef
// (the exact getAttr arguments fall in an extraction gap — verify upstream).
66 constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
67 auto memAttr =
b.getAttr<LLVM::MemoryEffectsAttr>(
70 func.setMemoryEffectsAttr(memAttr);
// Convergence is caller-controlled (e.g. barrier/shuffle builtins need it).
73 func.setConvergent(isConvergent);
// NOTE(review): fragment of createSPIRVBuiltinCall. Emits an LLVM call to a
// SPIR-V builtin declaration and mirrors the callee's attributes onto the
// call site, so the cconv and effect information survive into LLVM IR.
79 ConversionPatternRewriter &rewriter,
80 LLVM::LLVMFuncOp
func,
82 auto call = LLVM::CallOp::create(rewriter, loc,
func, args);
// Copy cconv plus convergent/nounwind/willreturn/memory-effects from the
// declaration configured by lookupOrCreateSPIRVFn.
83 call.setCConv(
func.getCConv());
84 call.setConvergentAttr(
func.getConvergentAttr());
85 call.setNoUnwindAttr(
func.getNoUnwindAttr());
86 call.setWillReturnAttr(
func.getWillReturnAttr());
87 call.setMemoryEffectsAttr(
func.getMemoryEffectsAttr());
// NOTE(review): fragment of the gpu.barrier lowering pattern. Rewrites the op
// into a call to the OpenCL barrier builtin ("_Z7barrierj" — presumably the
// Itanium mangling of barrier(uint); confirm against the OpenCL headers).
107 matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter)
const final {
109 constexpr StringLiteral funcName =
"_Z7barrierj";
// The builtin declaration is inserted into the nearest symbol table.
111 Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
112 assert(moduleOp &&
"Expecting module");
// barrier(uint flags) -> void: i32 flag argument, void result.
113 Type flagTy = rewriter.getI32Type();
114 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
115 LLVM::LLVMFuncOp func =
// Flag value 1 — looks like CLK_LOCAL_MEM_FENCE; TODO confirm constant.
121 constexpr int64_t localMemFenceFlag = 1;
122 Location loc = op->getLoc();
124 LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag);
// NOTE(review): fragment of LaunchConfigConversion — shared base pattern for
// launch-configuration queries (block id, grid dim, ...). Subclasses supply
// the dimension via the pure-virtual getDimension hook.
143 LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
144 MLIRContext *context,
145 const LLVMTypeConverter &typeConverter,
146 PatternBenefit benefit)
147 : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
148 funcName(funcName) {}
// Which dimension (x/y/z) the concrete op queries.
150 virtual gpu::Dimension getDimension(Operation *op)
const = 0;
153 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
154 ConversionPatternRewriter &rewriter)
const final {
156 assert(moduleOp &&
"Expecting module");
// OpenCL builtins take a uint dimension; result is converted to index type.
157 Type dimTy = rewriter.getI32Type();
158 Type indexTy = getTypeConverter()->getIndexType();
163 Location loc = op->getLoc();
164 gpu::Dimension dim = getDimension(op);
// Materialize the dimension enum as the i32 builtin argument.
165 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
166 static_cast<int64_t
>(dim));
// NOTE(review): fragment. Per-op instantiation of LaunchConfigConversion;
// each SourceOp specialization provides its mangled builtin name below.
174template <
typename SourceOp>
175struct LaunchConfigOpConversion final : LaunchConfigConversion {
176 static StringRef getFuncName();
178 explicit LaunchConfigOpConversion(
const LLVMTypeConverter &typeConverter,
179 PatternBenefit benefit = 1)
180 : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
// Forward the op's dimension attribute to the base pattern.
184 gpu::Dimension getDimension(Operation *op)
const final {
185 return cast<SourceOp>(op).getDimension();
// NOTE(review): fragment (closing braces lost to extraction). Maps each GPU
// launch-config op to its mangled OpenCL work-item builtin:
// gpu.block_id -> get_group_id(uint)
190StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
191 return "_Z12get_group_idj";
// gpu.grid_dim -> get_num_groups(uint)
195StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
196 return "_Z14get_num_groupsj";
// gpu.block_dim -> get_local_size(uint)
200StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
201 return "_Z14get_local_sizej";
// gpu.thread_id -> get_local_id(uint)
205StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
206 return "_Z12get_local_idj";
// gpu.global_id -> get_global_id(uint)
210StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
211 return "_Z13get_global_idj";
// NOTE(review): fragment (the switch header line is missing). Maps a GPU
// shuffle mode to the unmangled OpenCL sub-group shuffle builtin base name.
230 static StringRef getBaseName(gpu::ShuffleMode mode) {
232 case gpu::ShuffleMode::IDX:
233 return "sub_group_shuffle";
234 case gpu::ShuffleMode::XOR:
235 return "sub_group_shuffle_xor";
236 case gpu::ShuffleMode::UP:
237 return "sub_group_shuffle_up";
238 case gpu::ShuffleMode::DOWN:
239 return "sub_group_shuffle_down";
// All enum values handled above; reaching here is a programmer error.
241 llvm_unreachable(
"Unhandled shuffle mode");
// NOTE(review): fragment. Returns the Itanium-style mangling suffix for the
// shuffled value type plus the uint offset argument ("j"): Dh = _Float16,
// f = float, d = double; integer widths are handled in the dropped switch
// body (lines 250-261 lost to extraction). Unsupported types -> nullopt.
244 static std::optional<StringRef> getTypeMangling(Type type) {
246 .Case<Float16Type>([](
auto) {
return "Dhj"; })
247 .Case<Float32Type>([](
auto) {
return "fj"; })
248 .Case<Float64Type>([](
auto) {
return "dj"; })
249 .Case<IntegerType>([](
auto intTy) -> std::optional<StringRef> {
250 switch (intTy.getWidth()) {
262 .Default(std::nullopt);
// NOTE(review): fragments of three shuffle helpers.
// getFuncName: assembles the full mangled name "_Z<len><base><type-mangling>".
265 static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
267 StringRef baseName = getBaseName(mode);
268 std::optional<StringRef> typeMangling = getTypeMangling(type);
269 return llvm::formatv(
"_Z{}{}{}", baseName.size(), baseName,
272 typeMangling.value());
// getSubgroupSize: reads the parent function's intel_reqd_sub_group_size
// attribute (lookup of parentFunc falls in an extraction gap).
276 static std::optional<int> getSubgroupSize(Operation *op) {
280 return parentFunc.getIntelReqdSubGroupSize();
// hasValidWidth: the shuffle width must be a constant equal to the required
// subgroup size (the constant-matching lines are missing here).
283 static bool hasValidWidth(gpu::ShuffleOp op) {
285 Value width = op.getWidth();
287 val == getSubgroupSize(op);
// NOTE(review): fragments of the two value-adaptation helpers around the
// shuffle call. Before: bf16 is bitcast to i16 and i1 zero-extended to i8,
// since the OpenCL builtins have no bf16/i1 overloads; other types pass
// through (default case lost to extraction).
290 static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
291 ConversionPatternRewriter &rewriter) {
293 .Case([&](BFloat16Type) {
294 return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
297 .Case([&](IntegerType intTy) -> Value {
298 if (intTy.getWidth() == 1)
299 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
// After: undo the pre-shuffle widening — bitcast back for bf16, trunc back
// for i1 — restoring the op's original result type.
306 static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
308 ConversionPatternRewriter &rewriter) {
310 .Case([&](BFloat16Type) {
311 return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
313 .Case([&](IntegerType intTy) -> Value {
314 if (intTy.getWidth() == 1)
315 return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
// NOTE(review): fragment of the gpu.shuffle lowering. Fails the match when
// the width is not the subgroup size or the value type is unsupported; else
// calls the mangled sub_group_shuffle* builtin and reports valid=true.
322 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter)
const final {
324 if (!hasValidWidth(op))
325 return rewriter.notifyMatchFailure(
326 op,
"shuffle width and subgroup size mismatch");
328 Location loc = op->getLoc();
// Adapt the input (bf16->i16 bitcast, i1->i8 zext) before name mangling.
330 bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
331 std::optional<std::string> funcName =
332 getFuncName(op.getMode(), inValue.
getType());
334 return rewriter.notifyMatchFailure(op,
"unsupported value type");
337 assert(moduleOp &&
"Expecting module");
// Builtin signature: (value, offset) -> value.
338 Type valueType = inValue.
getType();
339 Type offsetType = adaptor.getOffset().getType();
340 Type resultType = valueType;
342 moduleOp, funcName.value(), {valueType, offsetType}, resultType,
345 std::array<Value, 2> args{inValue, adaptor.getOffset()};
// Convert the builtin's result back to the op's declared result type.
348 Value resultOrConversion =
349 bitcastOrTruncAfterShuffle(
result, op.getType(0), loc, rewriter);
// gpu.shuffle's second result (valid flag) is always true here, since the
// width was checked to equal the subgroup size.
352 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(),
true);
353 rewriter.replaceOp(op, {resultOrConversion, trueVal});
// NOTE(review): fragment. TypeConverter that rewrites memref memory spaces to
// the OpenCL address-space numbering (default memrefs -> CrossWorkgroup, i.e.
// global), and converts FunctionType signatures element-wise.
358class MemorySpaceToOpenCLMemorySpaceConverter final :
public TypeConverter {
360 MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
// Identity fallback for all other types.
361 addConversion([](Type t) {
return t; });
362 addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
// Map to the OpenCL global address space number derived from the SPIR-V
// CrossWorkgroup storage class; encoded as an i64 IntegerAttr.
369 spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
370 Attribute addrSpaceAttr =
371 IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
372 if (
auto rankedType = dyn_cast<MemRefType>(memRefType)) {
// Ranked case: rebuild the memref keeping shape/element/layout, swapping
// only the memory space (unranked branch lost to extraction).
373 return MemRefType::get(memRefType.
getShape(),
375 rankedType.getLayout(), addrSpaceAttr);
// Function types: recursively convert inputs and results.
380 addConversion([
this](FunctionType type) {
381 auto inputs = llvm::map_to_vector(
382 type.getInputs(), [
this](Type ty) { return convertType(ty); });
383 auto results = llvm::map_to_vector(
384 type.getResults(), [
this](Type ty) { return convertType(ty); });
385 return FunctionType::get(type.getContext(), inputs, results);
// NOTE(review): fragment. Single pattern handling the four subgroup query
// ops; the builtin name is chosen at compile time via if-constexpr.
394template <
typename SubgroupOp>
396 using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
400 matchAndRewrite(SubgroupOp op,
typename SubgroupOp::Adaptor adaptor,
401 ConversionPatternRewriter &rewriter)
const final {
// Immediately-invoked constexpr lambda selects the mangled builtin name
// (these builtins take no arguments, hence no trailing "j").
402 constexpr StringRef funcName = [] {
403 if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
404 return "_Z16get_sub_group_id";
405 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
406 return "_Z22get_sub_group_local_id";
407 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
408 return "_Z18get_num_sub_groups";
409 }
else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
410 return "_Z18get_sub_group_size";
414 Operation *moduleOp =
415 op->template getParentWithTrait<OpTrait::SymbolTable>();
// Builtins return a 32-bit value; widen to the converter's index type
// when the target index bitwidth differs (e.g. 64-bit index).
416 Type resultTy = rewriter.getI32Type();
417 LLVM::LLVMFuncOp func =
421 Location loc = op->getLoc();
424 Type indexTy = getTypeConverter()->getIndexType();
425 if (resultTy != indexTy) {
429 result = LLVM::ZExtOp::create(rewriter, loc, indexTy,
result);
432 rewriter.replaceOp(op,
result);
// NOTE(review): fragment of the pass driver. Sets up the LLVM type converter
// with the pass's index bitwidth option, rewrites memref address spaces to
// OpenCL numbering, marks the handled GPU ops illegal, and applies partial
// conversion.
441struct GPUToLLVMSPVConversionPass final
445 void runOnOperation() final {
447 RewritePatternSet
patterns(context);
449 LowerToLLVMOptions
options(context);
// Index width is a pass option: 64-bit or 32-bit target indices.
450 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
451 LLVMTypeConverter converter(context,
options);
452 LLVMConversionTarget
target(*context);
// NOTE(review): a second `converter` appears here — presumably it shadows
// the LLVM one inside a nested scope lost to extraction; verify upstream.
456 MemorySpaceToOpenCLMemorySpaceConverter converter(context);
// Replacer rewrites memref types in-place across the whole operation.
457 AttrTypeReplacer replacer;
459 -> std::optional<BaseMemRefType> {
460 return converter.convertType<BaseMemRefType>(origType);
// All GPU ops this pass knows how to lower must be gone after conversion.
469 target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
470 gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
471 gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
472 gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
473 gpu::ThreadIdOp, gpu::PrintfOp>();
// gpu.printf lowers to the OpenCL printf builtin (constant AS2 fmt string).
477 patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, 2,
478 LLVM::cconv::CConv::SPIR_FUNC,
479 "_Z6printfPU3AS2Kcz");
481 if (
failed(applyPartialConversion(getOperation(),
target,
// NOTE(review): fragments of the public population entry points. Maps GPU
// address spaces to OpenCL numbers and registers all conversion patterns.
495gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
496 constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
// Register every op pattern defined above against the shared converter.
505 GPUSubgroupOpConversion<gpu::LaneIdOp>,
506 GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
507 GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
508 GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
509 LaunchConfigOpConversion<gpu::BlockDimOp>,
510 LaunchConfigOpConversion<gpu::BlockIdOp>,
511 LaunchConfigOpConversion<gpu::GlobalIdOp>,
512 LaunchConfigOpConversion<gpu::GridDimOp>,
513 LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
// gpu.func lowering needs the OpenCL numbers for private/local memory and
// the reqd_work_group_size attribute name on LLVM funcs.
515 unsigned privateAddressSpace =
516 gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
517 unsigned localAddressSpace =
518 gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
519 OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
520 StringAttr kernelBlockSizeAttributeName =
521 LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
// Kernels get SPIR_KERNEL cconv; ordinary functions get SPIR_FUNC.
525 privateAddressSpace, localAddressSpace,
526 {}, kernelBlockSizeAttributeName,
527 LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
533 gpuAddressSpaceToOCLAddressSpace);
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static llvm::ManagedStatic< PassManagerOptions > options
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for operation conversions targeting the LLVM IR dialect.
const LLVMTypeConverter * getTypeConverter() const
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type getType() const
Return the type of this value.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateGpuToLLVMSPVConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
spirv::StorageClass addressSpaceToStorageClass(gpu::AddressSpace addressSpace)
const FrozenRewritePatternSet & patterns
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
llvm::TypeSwitch< T, ResultT > TypeSwitch