MLIR 23.0.0git
LowerGpuOpsToROCDLOps.cpp
Go to the documentation of this file.
1//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to generate ROCDLIR operations for higher-level
10// GPU operations.
11//
12//===----------------------------------------------------------------------===//
13
16#include "mlir/Pass/Pass.h"
18
41
44
45namespace mlir {
46#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
47#include "mlir/Conversion/Passes.h.inc"
48} // namespace mlir
49
50using namespace mlir;
51
52// Truncate or extend the result depending on the index bitwidth specified
53// by the LLVMTypeConverter options.
54static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
55 Location loc, Value value,
56 const LLVMTypeConverter &converter) {
57 int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
58 int64_t indexBitwidth = converter.getIndexTypeBitwidth();
59 auto indexBitwidthType =
60 IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
61 // TODO: use <=> in C++20.
62 if (indexBitwidth > intWidth) {
63 return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
64 }
65 if (indexBitwidth < intWidth) {
66 return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
67 }
68 return value;
69}
70
71/// Returns true if the given `gpu.func` can be safely called using the bare
72/// pointer calling convention.
73static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
74 bool canBeBare = true;
75 for (Type type : func.getArgumentTypes())
76 if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
77 canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
78 return canBeBare;
79}
80
/// Computes the current lane id as mbcnt.hi(-1, mbcnt.lo(-1, 0)), i.e. a
/// count of the active-lane bits below the current lane (the standard AMDGPU
/// lane-id idiom). Both intrinsic results are annotated `noundef` plus a
/// `range` attribute so LLVM can reason about the value: [0, 32) for the
/// low half, [0, 64) for the combined result.
static Value getLaneId(RewriterBase &rewriter, Location loc) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
  Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
  NamedAttribute noundef = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
  // Result range for mbcnt.lo: [0, 32).
  NamedAttribute lowRange = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getRangeAttrName(),
      LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
                                   APInt(32, 32)));
  // Result range for mbcnt.hi: [0, 64).
  NamedAttribute highRange = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getRangeAttrName(),
      LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
                                   APInt(32, 64)));
  Value mbcntLo = ROCDL::MbcntLoOp::create(
      rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
      /*res_attrs=*/
      rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
  Value laneId = ROCDL::MbcntHiOp::create(
      rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
      rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
  return laneId;
}
104
105/// Maximum number of threads per block dimension on AMD GPUs.
106static constexpr int64_t kMaxThreadsPerBlockDim = 1024;
107
108/// Emits a call to an OCKL block/grid size function corresponding to
109/// `indexKind` with argument `dim`, except that if the context around
110/// `contextOp` gives an exact size for that dimension, return that as
111/// an `i64` constant instead.
114 gpu::Dimension dim, Operation *contextOp,
115 std::optional<uint32_t> opUpperBound) {
116 Location loc = contextOp->getLoc();
117 MLIRContext *context = contextOp->getContext();
118
119 auto i32Ty = IntegerType::get(context, 32);
120 auto i64Ty = IntegerType::get(context, 64);
121
122 if (std::optional<uint32_t> knownDim =
123 gpu::getKnownDimensionSizeAround(contextOp, indexKind, dim))
124 return LLVM::ConstantOp::create(rewriter, loc,
125 rewriter.getI64IntegerAttr(*knownDim));
126
127 int32_t dimParam = static_cast<int32_t>(dim);
128
129 StringRef functionName;
130 switch (indexKind) {
131 case gpu::index_lowering::IndexKind::Block:
132 functionName = "__ockl_get_local_size";
133 break;
134 case gpu::index_lowering::IndexKind::Grid:
135 functionName = "__ockl_get_num_groups";
136 break;
137 case gpu::index_lowering::IndexKind::Cluster:
138 case gpu::index_lowering::IndexKind::Other:
139 llvm_unreachable("Not valid index kinds for ockl lookup");
140 }
141
142 // Declare the ockl function: i64 @functionName(i32).
143 auto fnType = LLVM::LLVMFunctionType::get(i64Ty, {i32Ty});
144 Operation *moduleOp = contextOp->getParentWithTrait<OpTrait::SymbolTable>();
145 LLVM::LLVMFuncOp funcOp =
146 getOrDefineFunction(moduleOp, loc, rewriter, functionName, fnType);
147
148 // Create the call.
149 Value dimConst = LLVM::ConstantOp::create(rewriter, loc, i32Ty, dimParam);
150 auto callOp =
151 LLVM::CallOp::create(rewriter, loc, funcOp, ValueRange{dimConst});
152
153 LLVM::ConstantRangeAttr range;
154 if (opUpperBound) {
155 range = LLVM::ConstantRangeAttr::get(
156 context, APInt(64, 1),
157 APInt(64, static_cast<uint64_t>(*opUpperBound) + 1));
158 } else if (indexKind == gpu::index_lowering::IndexKind::Block) {
159 // Set the hardware limit for block ranges as the bounds on block dim calls.
160 range = LLVM::ConstantRangeAttr::get(context, APInt(64, 1),
161 APInt(64, kMaxThreadsPerBlockDim + 1));
162 }
163 if (range) {
164 callOp.setResAttrsAttr(rewriter.getArrayAttr(rewriter.getDictionaryAttr(
165 rewriter.getNamedAttr(LLVM::LLVMDialect::getRangeAttrName(), range))));
166 }
167 return callOp.getResult();
168}
169
170static constexpr StringLiteral amdgcnDataLayout =
171 "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
172 "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
173 "32-v32:"
174 "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
175 "64-S32-A5-G1-ni:7:8:9";
176
177namespace {
178
179/// Lowers gpu.block_dim / gpu.grid_dim to direct __ockl_get_local_size /
180/// __ockl_get_num_groups function calls.
181template <typename OpTy>
182struct GPUDimOpToOcklCall final : ConvertOpToLLVMPattern<OpTy> {
183 GPUDimOpToOcklCall(const LLVMTypeConverter &converter,
185 : ConvertOpToLLVMPattern<OpTy>(converter), indexKind(indexKind) {}
186
187 LogicalResult
188 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
189 ConversionPatternRewriter &rewriter) const override {
190 Location loc = op.getLoc();
191
192 std::optional<uint32_t> opUpperBound;
193 if (auto bound = op.getUpperBound())
194 opUpperBound = static_cast<uint32_t>(bound->getZExtValue());
195
196 Value ocklCall = getKnownOrOcklDim(rewriter, indexKind, op.getDimension(),
197 op, opUpperBound);
198 Value result = truncOrExtToLLVMType(rewriter, loc, ocklCall,
199 *this->getTypeConverter());
200 rewriter.replaceOp(op, result);
201 return success();
202 }
203
204private:
205 const gpu::index_lowering::IndexKind indexKind;
206};
207
208struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
210
211 LogicalResult
212 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
213 ConversionPatternRewriter &rewriter) const override {
214 Location loc = op.getLoc();
215 MLIRContext *context = rewriter.getContext();
216 // convert to:
217 // %mlo = call noundef range(i32 0, 32)
218 // @llvm.amdgcn.mbcnt.lo(-1, 0)
219 // followed by:
220 // %lid = call noundef range(i32 0, 64)
221 // @llvm.amdgcn.mbcnt.hi(-1, %mlo)
222
223 Value laneId = getLaneId(rewriter, loc);
224 // Truncate or extend the result depending on the index bitwidth specified
225 // by the LLVMTypeConverter options.
226 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
227 if (indexBitwidth > 32) {
228 laneId = LLVM::SExtOp::create(
229 rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
230 } else if (indexBitwidth < 32) {
231 laneId = LLVM::TruncOp::create(
232 rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
233 }
234 rewriter.replaceOp(op, {laneId});
235 return success();
236 }
237};
238
239struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
241
242 GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
243 amdgpu::Chipset chipset)
245 chipset(chipset) {}
246
  /// Lowers `gpu.subgroup_size` to `rocdl.wavefrontsize`, adapting the i32
  /// result to the configured index bitwidth.
  LogicalResult
  matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // When the op carries a static upper bound, annotate the intrinsic with a
    // range attribute. The lower bound is 64 before gfx10 and 32 from gfx10
    // on (the two possible wavefront widths on these architectures).
    LLVM::ConstantRangeAttr bounds = nullptr;
    bool isBeforeGfx10 = chipset.majorVersion < 10;
    if (auto upperBoundAttr = op.getUpperBoundAttr()) {
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
          /*upper=*/op.getUpperBoundAttr().getInt() + 1);
    }
    Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
        rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
    // Match the index bitwidth expected by the type converter.
    wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
                                       *getTypeConverter());
    rewriter.replaceOp(op, {wavefrontOp});
    return success();
  }
264
265 const amdgpu::Chipset chipset;
266};
267
268struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
270
271 GPUSubgroupIdOpToROCDL(const LLVMTypeConverter &converter,
272 amdgpu::Chipset chipset)
273 : ConvertOpToLLVMPattern<gpu::SubgroupIdOp>(converter), chipset(chipset) {
274 }
275
276 LogicalResult
277 matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
278 ConversionPatternRewriter &rewriter) const override {
279 Location loc = op.getLoc();
280 auto int32Type = rewriter.getI32Type();
281
282 Value subgroupId;
283 if (chipset.majorVersion >= 12) {
284 // For gfx12+, use the hardware wave.id register directly.
285 LLVM::ConstantRangeAttr bounds;
286 if (auto upperBoundAttr = op.getUpperBoundAttr())
287 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
288 /*bitWidth=*/32, /*lower=*/0,
289 /*upper=*/upperBoundAttr.getInt());
290 subgroupId = ROCDL::WaveId::create(rewriter, loc, int32Type, bounds);
291 } else {
292 // For older architectures, compute:
293 // subgroup_id = linearized_thread_id / subgroup_size
294 // where linearized_thread_id = tid.x + dim.x * (tid.y + dim.y * tid.z)
295 auto tidX = ROCDL::ThreadIdXOp::create(rewriter, loc, int32Type);
296 auto tidY = ROCDL::ThreadIdYOp::create(rewriter, loc, int32Type);
297 auto tidZ = ROCDL::ThreadIdZOp::create(rewriter, loc, int32Type);
298 auto setBoundFromContext = [&](Operation *tidOp, gpu::Dimension dim) {
299 if (LLVM::ConstantRangeAttr range =
301 op, dim, std::nullopt,
302 gpu::index_lowering::IndexKind::Block,
304 tidOp->setAttr("range", range);
305 };
306 setBoundFromContext(tidX, gpu::Dimension::x);
307 setBoundFromContext(tidY, gpu::Dimension::y);
308 setBoundFromContext(tidZ, gpu::Dimension::z);
309
310 auto flags =
311 LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
312
313 auto getBlockDim = [&](gpu::Dimension dim) {
314 Value dim64 =
315 getKnownOrOcklDim(rewriter, gpu::index_lowering::IndexKind::Block,
316 dim, op, std::nullopt);
317 Value dimTrunc =
318 LLVM::TruncOp::create(rewriter, loc, int32Type, dim64, flags);
319 return dimTrunc;
320 };
321 Value dimX = getBlockDim(gpu::Dimension::x);
322 Value dimY = getBlockDim(gpu::Dimension::y);
323
324 // linearized = tid.x + dim.x * (tid.y + dim.y * tid.z)
325 // Thread IDs and dimensions are non-negative and small, so use nuw+nsw.
326 Value dimYxTidZ =
327 LLVM::MulOp::create(rewriter, loc, int32Type, dimY, tidZ, flags);
328 Value tidYPlusDimYxTidZ =
329 LLVM::AddOp::create(rewriter, loc, int32Type, tidY, dimYxTidZ, flags);
330 Value dimXxInner = LLVM::MulOp::create(rewriter, loc, int32Type, dimX,
331 tidYPlusDimYxTidZ, flags);
332 Value linearized = LLVM::AddOp::create(rewriter, loc, int32Type, tidX,
333 dimXxInner, flags);
334
335 Value subgroupSize =
336 ROCDL::WavefrontSizeOp::create(rewriter, loc, int32Type);
337 subgroupId = LLVM::UDivOp::create(rewriter, loc, int32Type, linearized,
338 subgroupSize);
339 }
340
341 subgroupId =
342 truncOrExtToLLVMType(rewriter, loc, subgroupId, *getTypeConverter());
343 rewriter.replaceOp(op, subgroupId);
344 return success();
345 }
346
347 const amdgpu::Chipset chipset;
348};
349
350static bool isSupportedReadLaneType(Type type) {
351 // https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics
352 if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
353 LLVM::LLVMPointerType>(type))
354 return true;
355
356 if (auto intType = dyn_cast<IntegerType>(type))
357 return llvm::is_contained({16, 32, 64},
358 static_cast<int>(intType.getWidth()));
359
360 if (auto vecType = dyn_cast<VectorType>(type)) {
361 Type elementType = vecType.getElementType();
362 if (elementType.isInteger(32))
363 return true;
364
365 if (vecType.getNumElements() == 2 &&
366 (isa<Float16Type, BFloat16Type>(elementType) ||
367 elementType.isInteger(16)))
368 return true;
369 }
370
371 return false;
372}
373
374struct GPUSubgroupBroadcastOpToROCDL
375 : public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> {
377
  /// Lowers `gpu.subgroup_broadcast` to readlane/readfirstlane. Values whose
  /// type the intrinsics support natively are broadcast with a single call;
  /// all other values are decomposed into i32 pieces, broadcast piecewise,
  /// and recomposed.
  LogicalResult
  matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSrc();
    // Fast path: the type is directly supported by the intrinsic.
    if (isSupportedReadLaneType(src.getType())) {
      Value result = createReadlaneOp(op, adaptor, rewriter, src);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Slow path: bitcast/split into i32 chunks, broadcast each chunk, then
    // reassemble a value of the original type.
    Type i32 = rewriter.getI32Type();
    Location loc = op.getLoc();
    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed,
                                    /*permitVariablySizedScalars=*/true)))
      return rewriter.notifyMatchFailure(op,
                                         "Unexpected decomposition failure");

    SmallVector<Value> results;
    results.reserve(decomposed.size());
    for (Value v : decomposed)
      results.emplace_back(createReadlaneOp(op, adaptor, rewriter, v));

    Value result = LLVM::composeValue(rewriter, loc, results, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
405
406private:
407 static Value createReadlaneOp(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
408 ConversionPatternRewriter &rewriter,
409 Value src) {
410 if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
411 return ROCDL::ReadlaneOp::create(rewriter, op.getLoc(), src.getType(),
412 src, adaptor.getLane());
413 } else { // first_active_lane
414 return ROCDL::ReadfirstlaneOp::create(rewriter, op.getLoc(),
415 src.getType(), src);
416 }
417 }
418};
419
420struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
422
  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// Use the `width` argument to see if src lane is participating.
  /// If not the dstLane would be itself.
  ///
  /// Shuffle with DS Bpermute:
  /// let shflMode = [xor, up, down, idx]
  /// let width = 32(usually warpsize), step = [1, 2, 4, 8, 16, ... , width].
  /// 1. curLaneId = using mbcnt.lo + mbcnt.hi
  /// 2. widthOrZeroIfOutside = (curLaneId + width) & -width
  /// 3. dstLane = shflMode(curLaneId, step)
  /// 4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  /// 5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  /// 6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
  /// 7. bpermute(dwordAlignedDstLane, shfl_value).
  ///
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value initShflValue = adaptor.getValue();

    Value srcLaneId = getLaneId(rewriter, loc);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
    Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
    Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
    // Step 2: rounds (srcLaneId + width) down to a multiple of width, giving
    // the exclusive upper bound of the current lane's shuffle group.
    Value widthOrZeroIfOutside =
        LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
    Value dstLane;

    // Step 3: compute the source lane according to the shuffle mode.
    switch (op.getMode()) {
    case gpu::ShuffleMode::UP:
      dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::DOWN:
      dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    }
    // Steps 4-5: fall back to the current lane when the source lane is
    // outside the shuffle group.
    Value isActiveSrcLane = LLVM::ICmpOp::create(
        rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
                                                 dstLane, srcLaneId);
    // Step 6: ds_bpermute addresses lanes by byte offset (lane * 4).
    Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
    Value dwordAlignedDstLane =
        LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);

    // Step 7: bpermute works on i32 values, so decompose, permute each piece,
    // and recompose a value of the original type.
    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type,
                                    decomposed)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
                                              dwordAlignedDstLane, v);
      swizzled.emplace_back(res);
    }
    Value shflValue =
        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
497};
498
499struct GPUBarrierOpLowering final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
500 GPUBarrierOpLowering(const LLVMTypeConverter &converter,
501 amdgpu::Chipset chipset)
502 : ConvertOpToLLVMPattern<gpu::BarrierOp>(converter), chipset(chipset) {}
503
504 amdgpu::Chipset chipset;
505
506 LogicalResult
507 matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor,
508 ConversionPatternRewriter &rewriter) const override {
509 Location loc = op.getLoc();
510
511 // Analyze the address_spaces attribute to determine fence behavior.
512 bool fenceGlobal = false;
513 bool fenceLDS = false;
514 std::optional<ArrayAttr> addrSpacesToFence = op.getAddressSpaces();
515
516 if (addrSpacesToFence) {
517 for (auto spaceAttr :
518 addrSpacesToFence->getAsRange<gpu::AddressSpaceAttr>()) {
519 switch (spaceAttr.getValue()) {
520 case gpu::AddressSpace::Global:
521 fenceGlobal = true;
522 break;
523 case gpu::AddressSpace::Workgroup:
524 fenceLDS = true;
525 break;
526 case gpu::AddressSpace::Private:
527 break;
528 }
529 }
530 } else {
531 // Default semantics match __syncthreads() and fence both global and LDS.
532 fenceGlobal = true;
533 fenceLDS = true;
534 }
535
536 Attribute mmra;
537 if (fenceLDS && !fenceGlobal) {
538 mmra =
539 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
540 } else if (fenceGlobal && !fenceLDS) {
541 mmra = rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as",
542 "global");
543 }
544
545 constexpr llvm::StringLiteral scope = "workgroup";
546
547 bool emitFences = fenceGlobal || fenceLDS;
548 // Emit release fence if needed.
549 if (emitFences) {
550 auto relFence = LLVM::FenceOp::create(
551 rewriter, loc, LLVM::AtomicOrdering::release, scope);
552 if (mmra)
553 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
554 mmra);
555 }
556
557 if (chipset.majorVersion < 12) {
558 ROCDL::SBarrierOp::create(rewriter, loc);
559 } else {
560 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
561 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
562 }
563
564 if (emitFences) {
565 auto acqFence = LLVM::FenceOp::create(
566 rewriter, loc, LLVM::AtomicOrdering::acquire, scope);
567 if (mmra)
568 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
569 mmra);
570 }
571
572 rewriter.eraseOp(op);
573 return success();
574 }
575};
576
577/// Import the GPU Ops to ROCDL Patterns.
578#include "GPUToROCDL.cpp.inc"
579
580// A pass that replaces all occurrences of GPU device operations with their
581// corresponding ROCDL equivalent.
582//
583// This pass only handles device code and is not meant to be run on GPU host
584// code.
585struct LowerGpuOpsToROCDLOpsPass final
586 : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
587 using Base::Base;
588
589 void getDependentDialects(DialectRegistry &registry) const override {
590 Base::getDependentDialects(registry);
592 }
593
594 void runOnOperation() override {
595 gpu::GPUModuleOp m = getOperation();
596 MLIRContext *ctx = m.getContext();
597
598 auto llvmDataLayout = m->getAttrOfType<StringAttr>(
599 LLVM::LLVMDialect::getDataLayoutAttrName());
600 if (!llvmDataLayout) {
601 llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
602 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
603 }
604 // Request C wrapper emission.
605 for (auto func : m.getOps<func::FuncOp>()) {
606 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
607 UnitAttr::get(ctx));
608 }
609
610 FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
611 if (failed(maybeChipset)) {
612 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
613 return signalPassFailure();
614 }
615
616 /// Customize the bitwidth used for the device side index computations.
618 ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
619 options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
620 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
621 options.overrideIndexBitwidth(indexBitwidth);
622
623 if (useBarePtrCallConv) {
624 options.useBarePtrCallConv = true;
625 WalkResult canUseBarePointers =
626 m.walk([](gpu::GPUFuncOp func) -> WalkResult {
628 return WalkResult::advance();
629 return WalkResult::interrupt();
630 });
631 if (canUseBarePointers.wasInterrupted()) {
632 emitError(UnknownLoc::get(ctx),
633 "bare pointer calling convention requires all memrefs to "
634 "have static shape and use the identity map");
635 return signalPassFailure();
636 }
637 }
638
639 // Apply in-dialect lowering. In-dialect lowering will replace
640 // ops which need to be lowered further, which is not supported by a
641 // single conversion pass.
642 {
643 RewritePatternSet patterns(ctx);
645 populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
646 (void)applyPatternsGreedily(m, std::move(patterns));
647 }
648
649 LLVMTypeConverter converter(ctx, options);
651
652 RewritePatternSet llvmPatterns(ctx);
654
655 llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
656 allowedDialects.end());
657 for (Dialect *dialect : ctx->getLoadedDialects()) {
658 bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
659 // Empty `allowedDialectsSet` means all dialects are allowed.
660 if (!allowedDialectsSet.empty() && !allowed)
661 continue;
662
663 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
664 if (!iface) {
665 // Error out if dialect was explicily specified but doesn't implement
666 // conversion interface.
667 if (allowed) {
668 m.emitError()
669 << "dialect does not implement ConvertToLLVMPatternInterface: "
670 << dialect->getNamespace();
671 return signalPassFailure();
672 }
673 continue;
674 }
675
676 iface->populateConvertToLLVMConversionPatterns(target, converter,
677 llvmPatterns);
678 }
679
680 populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
681 *maybeChipset);
682 populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
683 *maybeChipset);
685 if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
686 signalPassFailure();
687 auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
688 auto reqdWorkGroupSizeAttrHelper =
689 rocdlDialect->getReqdWorkGroupSizeAttrHelper();
690 auto flatWorkGroupSizeAttrHelper =
691 rocdlDialect->getFlatWorkGroupSizeAttrHelper();
692 // Manually rewrite known block size attributes so the LLVMIR translation
693 // infrastructure can pick them up.
694 m.walk([&](LLVM::LLVMFuncOp op) {
695 if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
696 auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
697 // Also set up the rocdl.flat_work_group_size attribute to prevent
698 // conflicting metadata.
699 uint32_t flatSize = 1;
700 for (uint32_t size : blockSizes.asArrayRef()) {
701 flatSize *= size;
702 }
703 StringAttr flatSizeAttr =
704 StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
705 flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
706 }
707 });
708 }
709};
710
711} // namespace
712
714 target.addIllegalOp<func::FuncOp>();
715 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
716 target.addLegalDialect<ROCDL::ROCDLDialect>();
717 target.addIllegalDialect<gpu::GPUDialect>();
718 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
719 LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
720 LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
721 // These ops are legal for f32 type.
722 target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
723 return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
724 });
725 // TODO: Remove once we support replacing non-root ops.
726 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
727}
728
730 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
735 auto *rocdlDialect =
736 converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
737 populateWithGenerated(patterns);
738 patterns.add<
739 gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
740 ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
741 converter, IndexKind::Block, IntrType::Id);
743 gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
744 converter, IndexKind::Grid, IntrType::Id);
745 patterns.add<GPUDimOpToOcklCall<gpu::BlockDimOp>>(converter,
746 IndexKind::Block);
747 patterns.add<GPUDimOpToOcklCall<gpu::GridDimOp>>(converter, IndexKind::Grid);
748 patterns.add<GPUReturnOpLowering>(converter);
749 patterns.add<GPUFuncOpLowering>(
750 converter,
752 /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
753 /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
754 rocdlDialect->getKernelAttrHelper().getName(),
755 rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName(),
756 /*kernelClusterSizeAttributeName=*/{}});
757 if (Runtime::HIP == runtime) {
758 patterns.add<GPUPrintfOpToHIPLowering>(converter);
759 } else if (Runtime::OpenCL == runtime) {
760 // Use address space = 4 to match the OpenCL definition of printf()
761 patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
762 }
763 // TODO: Add alignment for workgroup memory
764 patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
765
766 patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
767 GPUSubgroupBroadcastOpToROCDL>(converter);
768 patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
769 GPUBarrierOpLowering>(converter, chipset);
770
771 populateMathToROCDLConversionPatterns(converter, patterns, chipset);
772}
return success()
b getContext())
static Value getLaneId(RewriterBase &rewriter, Location loc)
static constexpr int64_t kMaxThreadsPerBlockDim
Maximum number of threads per block dimension on AMD GPUs.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func)
Returns true if the given gpu.func can be safely called using the bare pointer calling convention.
static constexpr StringLiteral amdgcnDataLayout
static Value getKnownOrOcklDim(RewriterBase &rewriter, gpu::index_lowering::IndexKind indexKind, gpu::Dimension dim, Operation *contextOp, std::optional< uint32_t > opUpperBound)
Emits a call to an OCKL block/grid size function corresponding to indexKind with argument dim,...
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, Location loc, Value value, const LLVMTypeConverter &converter)
static llvm::ManagedStatic< PassManagerOptions > options
#define add(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:277
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:611
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
LogicalResult decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType, SmallVectorImpl< Value > &result, bool permitVariablySizedScalars=false)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Definition Pattern.cpp:495
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
Definition Pattern.cpp:594
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Runtime
Potential runtimes for AMD GPU kernels.
Definition Runtimes.h:15
LLVM::ConstantRangeAttr getIndexOpRange(Operation *op, gpu::Dimension dim, std::optional< uint32_t > opUpperBound, IndexKind indexKind, IntrType intrType, unsigned bitWidth)
Returns a ConstantRangeAttr for a GPU index op, or nullptr if no bounds are found.
std::optional< uint32_t > getKnownDimensionSizeAround(Operation *op, DimensionKind kind, Dimension dim)
Retrieve the constant bounds for a given dimension and dimension kind from the context surrounding op...
Include the generated interface declarations.
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, gpu::amd::Runtime runtime, amdgpu::Chipset chipset)
Collect a set of patterns to convert from the GPU dialect to ROCDL.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, std::optional< amdgpu::Chipset > chipset)
Populate the given list with patterns that convert from Math to ROCDL calls.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
Definition Passes.h:91
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void configureGpuToROCDLConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to ROCDL.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns, std::optional< amdgpu::Chipset > maybeChipset)
Tries to promote gpu.shuffles to specialized AMDGPU intrinsics.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
The lowering of gpu.printf to a call to HIP hostcalls.
The lowering of gpu.printf to a call to an external printf() function.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14