MLIR 23.0.0git
LowerGpuOpsToROCDLOps.cpp
Go to the documentation of this file.
1//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to generate ROCDLIR operations for higher-level
10// GPU operations.
11//
12//===----------------------------------------------------------------------===//
13
16#include "mlir/Pass/Pass.h"
18
41
44
45namespace mlir {
46#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
47#include "mlir/Conversion/Passes.h.inc"
48} // namespace mlir
49
50using namespace mlir;
51
52// Truncate or extend the result depending on the index bitwidth specified
53// by the LLVMTypeConverter options.
54static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
55 Location loc, Value value,
56 const LLVMTypeConverter &converter) {
57 int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
58 int64_t indexBitwidth = converter.getIndexTypeBitwidth();
59 auto indexBitwidthType =
60 IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
61 // TODO: use <=> in C++20.
62 if (indexBitwidth > intWidth) {
63 return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
64 }
65 if (indexBitwidth < intWidth) {
66 return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
67 }
68 return value;
69}
70
71/// Returns true if the given `gpu.func` can be safely called using the bare
72/// pointer calling convention.
73static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
74 bool canBeBare = true;
75 for (Type type : func.getArgumentTypes())
76 if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
77 canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
78 return canBeBare;
79}
80
81static Value getLaneId(RewriterBase &rewriter, Location loc) {
82 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
83 Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
84 Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
85 NamedAttribute noundef = rewriter.getNamedAttr(
86 LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
87 NamedAttribute lowRange = rewriter.getNamedAttr(
88 LLVM::LLVMDialect::getRangeAttrName(),
89 LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
90 APInt(32, 32)));
91 NamedAttribute highRange = rewriter.getNamedAttr(
92 LLVM::LLVMDialect::getRangeAttrName(),
93 LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
94 APInt(32, 64)));
95 Value mbcntLo = ROCDL::MbcntLoOp::create(
96 rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
97 /*res_attrs=*/
98 rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
99 Value laneId = ROCDL::MbcntHiOp::create(
100 rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
101 rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
102 return laneId;
103}
104
/// Maximum number of threads per block dimension on AMD GPUs. Used below as
/// the default (exclusive-minus-one) upper bound on `__ockl_get_local_size`
/// results when no tighter bound is known from the op or its context.
static constexpr int64_t kMaxThreadsPerBlockDim = 1024;
107
108/// Emits a call to an OCKL block/grid size function corresponding to
109/// `indexKind` with argument `dim`, except that if the context around
110/// `contextOp` gives an exact size for that dimension, return that as
111/// an `i64` constant instead.
114 gpu::Dimension dim, Operation *contextOp,
115 std::optional<uint32_t> opUpperBound) {
116 Location loc = contextOp->getLoc();
117 MLIRContext *context = contextOp->getContext();
118
119 auto i32Ty = IntegerType::get(context, 32);
120 auto i64Ty = IntegerType::get(context, 64);
121
122 if (std::optional<uint32_t> knownDim =
123 gpu::getKnownDimensionSizeAround(contextOp, indexKind, dim))
124 return LLVM::ConstantOp::create(rewriter, loc,
125 rewriter.getI64IntegerAttr(*knownDim));
126
127 int32_t dimParam = static_cast<int32_t>(dim);
128
129 StringRef functionName;
130 switch (indexKind) {
131 case gpu::index_lowering::IndexKind::Block:
132 functionName = "__ockl_get_local_size";
133 break;
134 case gpu::index_lowering::IndexKind::Grid:
135 functionName = "__ockl_get_num_groups";
136 break;
137 case gpu::index_lowering::IndexKind::Cluster:
138 case gpu::index_lowering::IndexKind::Other:
139 llvm_unreachable("Not valid index kinds for ockl lookup");
140 }
141
142 // Declare the ockl function: i64 @functionName(i32).
143 auto fnType = LLVM::LLVMFunctionType::get(i64Ty, {i32Ty});
144 Operation *moduleOp = contextOp->getParentWithTrait<OpTrait::SymbolTable>();
145 LLVM::LLVMFuncOp funcOp =
146 getOrDefineFunction(moduleOp, loc, rewriter, functionName, fnType);
147
148 // Create the call.
149 Value dimConst = LLVM::ConstantOp::create(rewriter, loc, i32Ty, dimParam);
150 auto callOp =
151 LLVM::CallOp::create(rewriter, loc, funcOp, ValueRange{dimConst});
152
153 LLVM::ConstantRangeAttr range;
154 if (opUpperBound) {
155 range = LLVM::ConstantRangeAttr::get(
156 context, APInt(64, 1),
157 APInt(64, static_cast<uint64_t>(*opUpperBound) + 1));
158 } else if (indexKind == gpu::index_lowering::IndexKind::Block) {
159 // Set the hardware limit for block ranges as the bounds on block dim calls.
160 range = LLVM::ConstantRangeAttr::get(context, APInt(64, 1),
161 APInt(64, kMaxThreadsPerBlockDim + 1));
162 }
163 if (range) {
164 callOp.setResAttrsAttr(rewriter.getArrayAttr(rewriter.getDictionaryAttr(
165 rewriter.getNamedAttr(LLVM::LLVMDialect::getRangeAttrName(), range))));
166 }
167 return callOp.getResult();
168}
169
/// Default LLVM data-layout string for amdgcn targets. Installed on the GPU
/// module as the `llvm.data_layout` attribute when the module does not
/// already specify one (see the pass's runOnOperation below).
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
    "32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";
176
177namespace {
178
179/// Lowers gpu.block_dim / gpu.grid_dim to direct __ockl_get_local_size /
180/// __ockl_get_num_groups function calls.
181template <typename OpTy>
182struct GPUDimOpToOcklCall final : ConvertOpToLLVMPattern<OpTy> {
183 GPUDimOpToOcklCall(const LLVMTypeConverter &converter,
185 : ConvertOpToLLVMPattern<OpTy>(converter), indexKind(indexKind) {}
186
187 LogicalResult
188 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
189 ConversionPatternRewriter &rewriter) const override {
190 Location loc = op.getLoc();
191
192 std::optional<uint32_t> opUpperBound;
193 if (auto bound = op.getUpperBound())
194 opUpperBound = static_cast<uint32_t>(bound->getZExtValue());
195
196 Value ocklCall = getKnownOrOcklDim(rewriter, indexKind, op.getDimension(),
197 op, opUpperBound);
198 Value result = truncOrExtToLLVMType(rewriter, loc, ocklCall,
199 *this->getTypeConverter());
200 rewriter.replaceOp(op, result);
201 return success();
202 }
203
204private:
205 const gpu::index_lowering::IndexKind indexKind;
206};
207
208struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
210
211 LogicalResult
212 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
213 ConversionPatternRewriter &rewriter) const override {
214 Location loc = op.getLoc();
215 MLIRContext *context = rewriter.getContext();
216 // convert to:
217 // %mlo = call noundef range(i32 0, 32)
218 // @llvm.amdgcn.mbcnt.lo(-1, 0)
219 // followed by:
220 // %lid = call noundef range(i32 0, 64)
221 // @llvm.amdgcn.mbcnt.hi(-1, %mlo)
222
223 Value laneId = getLaneId(rewriter, loc);
224 // Truncate or extend the result depending on the index bitwidth specified
225 // by the LLVMTypeConverter options.
226 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
227 if (indexBitwidth > 32) {
228 laneId = LLVM::SExtOp::create(
229 rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
230 } else if (indexBitwidth < 32) {
231 laneId = LLVM::TruncOp::create(
232 rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
233 }
234 rewriter.replaceOp(op, {laneId});
235 return success();
236 }
237};
238
239struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
241
242 GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
243 amdgpu::Chipset chipset)
245 chipset(chipset) {}
246
247 LogicalResult
248 matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
249 ConversionPatternRewriter &rewriter) const override {
250 LLVM::ConstantRangeAttr bounds = nullptr;
251 bool isBeforeGfx10 = chipset.majorVersion < 10;
252 if (auto upperBoundAttr = op.getUpperBoundAttr()) {
253 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
254 /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
255 /*upper=*/op.getUpperBoundAttr().getInt() + 1);
256 }
257 Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
258 rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
259 wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
260 *getTypeConverter());
261 rewriter.replaceOp(op, {wavefrontOp});
262 return success();
263 }
264
265 const amdgpu::Chipset chipset;
266};
267
268struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
270
271 GPUSubgroupIdOpToROCDL(const LLVMTypeConverter &converter,
272 amdgpu::Chipset chipset)
273 : ConvertOpToLLVMPattern<gpu::SubgroupIdOp>(converter), chipset(chipset) {
274 }
275
276 LogicalResult
277 matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
278 ConversionPatternRewriter &rewriter) const override {
279 Location loc = op.getLoc();
280 auto int32Type = rewriter.getI32Type();
281
282 Value subgroupId;
283 if (chipset.majorVersion >= 12) {
284 // For gfx12+, use the hardware wave.id register directly.
285 LLVM::ConstantRangeAttr bounds;
286 if (auto upperBoundAttr = op.getUpperBoundAttr())
287 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
288 /*bitWidth=*/32, /*lower=*/0,
289 /*upper=*/upperBoundAttr.getInt());
290 subgroupId = ROCDL::WaveId::create(rewriter, loc, int32Type, bounds);
291 } else {
292 // For older architectures, compute:
293 // subgroup_id = linearized_thread_id / subgroup_size
294 // where linearized_thread_id = tid.x + dim.x * (tid.y + dim.y * tid.z)
295 auto tidX = ROCDL::ThreadIdXOp::create(rewriter, loc, int32Type);
296 auto tidY = ROCDL::ThreadIdYOp::create(rewriter, loc, int32Type);
297 auto tidZ = ROCDL::ThreadIdZOp::create(rewriter, loc, int32Type);
298 auto setBoundFromContext = [&](Operation *tidOp, gpu::Dimension dim) {
299 if (LLVM::ConstantRangeAttr range =
301 op, dim, std::nullopt,
302 gpu::index_lowering::IndexKind::Block,
304 tidOp->setAttr("range", range);
305 };
306 setBoundFromContext(tidX, gpu::Dimension::x);
307 setBoundFromContext(tidY, gpu::Dimension::y);
308 setBoundFromContext(tidZ, gpu::Dimension::z);
309
310 auto flags =
311 LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
312
313 auto getBlockDim = [&](gpu::Dimension dim) {
314 Value dim64 =
315 getKnownOrOcklDim(rewriter, gpu::index_lowering::IndexKind::Block,
316 dim, op, std::nullopt);
317 Value dimTrunc =
318 LLVM::TruncOp::create(rewriter, loc, int32Type, dim64, flags);
319 return dimTrunc;
320 };
321 Value dimX = getBlockDim(gpu::Dimension::x);
322 Value dimY = getBlockDim(gpu::Dimension::y);
323
324 // linearized = tid.x + dim.x * (tid.y + dim.y * tid.z)
325 // Thread IDs and dimensions are non-negative and small, so use nuw+nsw.
326 Value dimYxTidZ =
327 LLVM::MulOp::create(rewriter, loc, int32Type, dimY, tidZ, flags);
328 Value tidYPlusDimYxTidZ =
329 LLVM::AddOp::create(rewriter, loc, int32Type, tidY, dimYxTidZ, flags);
330 Value dimXxInner = LLVM::MulOp::create(rewriter, loc, int32Type, dimX,
331 tidYPlusDimYxTidZ, flags);
332 Value linearized = LLVM::AddOp::create(rewriter, loc, int32Type, tidX,
333 dimXxInner, flags);
334
335 Value subgroupSize =
336 ROCDL::WavefrontSizeOp::create(rewriter, loc, int32Type);
337 subgroupId = LLVM::UDivOp::create(rewriter, loc, int32Type, linearized,
338 subgroupSize);
339 }
340
341 subgroupId =
342 truncOrExtToLLVMType(rewriter, loc, subgroupId, *getTypeConverter());
343 rewriter.replaceOp(op, subgroupId);
344 return success();
345 }
346
347 const amdgpu::Chipset chipset;
348};
349
350static bool isSupportedReadLaneType(Type type) {
351 // https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics
352 if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
353 LLVM::LLVMPointerType>(type))
354 return true;
355
356 if (auto intType = dyn_cast<IntegerType>(type))
357 return llvm::is_contained({16, 32, 64},
358 static_cast<int>(intType.getWidth()));
359
360 if (auto vecType = dyn_cast<VectorType>(type)) {
361 Type elementType = vecType.getElementType();
362 if (elementType.isInteger(32))
363 return true;
364
365 if (vecType.getNumElements() == 2 &&
366 (isa<Float16Type, BFloat16Type>(elementType) ||
367 elementType.isInteger(16)))
368 return true;
369 }
370
371 return false;
372}
373
374struct GPUSubgroupBroadcastOpToROCDL
375 : public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> {
377
378 LogicalResult
379 matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
380 ConversionPatternRewriter &rewriter) const override {
381 Value src = adaptor.getSrc();
382 if (isSupportedReadLaneType(src.getType())) {
383 Value result = createReadlaneOp(op, adaptor, rewriter, src);
384 rewriter.replaceOp(op, result);
385 return success();
386 }
387
388 Type i32 = rewriter.getI32Type();
389 Location loc = op.getLoc();
390 SmallVector<Value> decomposed;
391 if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed,
392 /*permitVariablySizedScalars=*/true)))
393 return rewriter.notifyMatchFailure(op,
394 "Unexpected decomposition failure");
395
396 SmallVector<Value> results;
397 results.reserve(decomposed.size());
398 for (Value v : decomposed)
399 results.emplace_back(createReadlaneOp(op, adaptor, rewriter, v));
400
401 Value result = LLVM::composeValue(rewriter, loc, results, src.getType());
402 rewriter.replaceOp(op, result);
403 return success();
404 }
405
406private:
407 static Value createReadlaneOp(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
408 ConversionPatternRewriter &rewriter,
409 Value src) {
410 if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
411 return ROCDL::ReadlaneOp::create(rewriter, op.getLoc(), src.getType(),
412 src, adaptor.getLane());
413 } else { // first_active_lane
414 return ROCDL::ReadfirstlaneOp::create(rewriter, op.getLoc(),
415 src.getType(), src);
416 }
417 }
418};
419
420struct GPUBallotOpToROCDL : public ConvertOpToLLVMPattern<gpu::BallotOp> {
422
423 LogicalResult
424 matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
425 ConversionPatternRewriter &rewriter) const override {
426 auto intType = cast<IntegerType>(op.getType());
427 unsigned width = intType.getWidth();
428
429 // ROCDL ballot natively supports i32 and i64 for wavefront sizes of
430 // 32 and 64 lanes.
431 if (width != 32 && width != 64)
432 return rewriter.notifyMatchFailure(
433 op, "rocdl.ballot only supports i32 and i64 result types");
434
435 rewriter.replaceOpWithNewOp<ROCDL::BallotOp>(op, op.getType(),
436 adaptor.getPredicate());
437 return success();
438 }
439};
440
441struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
443
444 /// Lowers a shuffle to the corresponding ROCDL ops.
445 ///
446 /// Use the `width` argument to see if src lane is participating.
447 /// If not the dstLane would be itself.
448 ///
449 /// Shuffle with DS Bpermute:
450 /// let shflMode = [xor, up, down, idx]
451 /// let width = 32(usually warpsize), step = [1, 2, 4, 8, 16, ... , width].
452 /// 1. curLaneId = using mbcnt.lo + mbcnt.hi
453 /// 2. widthOrZeroIfOutside = (curLaneId + width) & -width
454 /// 3. dstLane = shflMode(curLaneId, step)
455 /// 4. isActiveSrcLane = dstLane < isActiveSrcLane
456 /// 5. dstLane = isActiveSrcLane ? dstLane : curLaneId
457 /// 6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
458 /// 7. bpermute(dwordAlignedDstLane, shfl_value).
459 ///
460 LogicalResult
461 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
462 ConversionPatternRewriter &rewriter) const override {
463 Location loc = op->getLoc();
464 Value initShflValue = adaptor.getValue();
465
466 Value srcLaneId = getLaneId(rewriter, loc);
467
468 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
469 Value width = adaptor.getWidth();
470 Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
471 Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
472 Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
473 Value widthOrZeroIfOutside =
474 LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
475 Value dstLane;
476
477 switch (op.getMode()) {
478 case gpu::ShuffleMode::UP:
479 dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
480 adaptor.getOffset());
481 break;
482 case gpu::ShuffleMode::DOWN:
483 dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
484 adaptor.getOffset());
485 break;
486 case gpu::ShuffleMode::XOR:
487 dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
488 adaptor.getOffset());
489 break;
490 case gpu::ShuffleMode::IDX:
491 dstLane = adaptor.getOffset();
492 break;
493 }
494 Value isActiveSrcLane = LLVM::ICmpOp::create(
495 rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
496 Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
497 dstLane, srcLaneId);
498 Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
499 Value dwordAlignedDstLane =
500 LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);
501
502 SmallVector<Value> decomposed;
503 if (failed(LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type,
504 decomposed)))
505 return rewriter.notifyMatchFailure(op,
506 "failed to decompose value to i32");
507 SmallVector<Value> swizzled;
508 for (Value v : decomposed) {
509 Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
510 dwordAlignedDstLane, v);
511 swizzled.emplace_back(res);
512 }
513 Value shflValue =
514 LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
515 rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
516 return success();
517 }
518};
519
520struct GPUBarrierOpLowering final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
521 GPUBarrierOpLowering(const LLVMTypeConverter &converter,
522 amdgpu::Chipset chipset)
523 : ConvertOpToLLVMPattern<gpu::BarrierOp>(converter), chipset(chipset) {}
524
525 amdgpu::Chipset chipset;
526
527 LogicalResult
528 matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor,
529 ConversionPatternRewriter &rewriter) const override {
530 Location loc = op.getLoc();
531
532 // Analyze the address_spaces attribute to determine fence behavior.
533 bool fenceGlobal = false;
534 bool fenceLDS = false;
535 std::optional<ArrayAttr> addrSpacesToFence = op.getAddressSpaces();
536
537 if (addrSpacesToFence) {
538 for (auto spaceAttr :
539 addrSpacesToFence->getAsRange<gpu::AddressSpaceAttr>()) {
540 switch (spaceAttr.getValue()) {
541 case gpu::AddressSpace::Global:
542 fenceGlobal = true;
543 break;
544 case gpu::AddressSpace::Workgroup:
545 fenceLDS = true;
546 break;
547 case gpu::AddressSpace::Private:
548 case gpu::AddressSpace::Constant:
549 // Private is thread-local, constant is read-only; no fencing needed.
550 break;
551 }
552 }
553 } else {
554 // Default semantics match __syncthreads() and fence both global and LDS.
555 fenceGlobal = true;
556 fenceLDS = true;
557 }
558
559 Attribute mmra;
560 if (fenceLDS && !fenceGlobal) {
561 mmra =
562 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
563 } else if (fenceGlobal && !fenceLDS) {
564 mmra = rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as",
565 "global");
566 }
567
568 constexpr llvm::StringLiteral scope = "workgroup";
569
570 bool emitFences = fenceGlobal || fenceLDS;
571 // Emit release fence if needed.
572 if (emitFences) {
573 auto relFence = LLVM::FenceOp::create(
574 rewriter, loc, LLVM::AtomicOrdering::release, scope);
575 if (mmra)
576 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
577 mmra);
578 }
579
580 if (chipset.majorVersion < 12) {
581 ROCDL::SBarrierOp::create(rewriter, loc);
582 } else {
583 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
584 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
585 }
586
587 if (emitFences) {
588 auto acqFence = LLVM::FenceOp::create(
589 rewriter, loc, LLVM::AtomicOrdering::acquire, scope);
590 if (mmra)
591 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
592 mmra);
593 }
594
595 rewriter.eraseOp(op);
596 return success();
597 }
598};
599
600/// Import the GPU Ops to ROCDL Patterns.
601#include "GPUToROCDL.cpp.inc"
602
// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
struct LowerGpuOpsToROCDLOpsPass final
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    // NOTE(review): a line (presumably registering additional dialects on
    // `registry`) was lost in extraction here — confirm against upstream.
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    // Install the default amdgcn data layout if the module carries none.
    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    // `chipset` is the pass option string (e.g. "gfx90a").
    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    /// Customize the bitwidth used for the device side index computations.
    // NOTE(review): the line declaring `options` (a LowerToLLVMOptions
    // construction) was lost in extraction — confirm against upstream.
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            // NOTE(review): the guard line (presumably
            // `if (canBeCalledWithBarePointers(func))`) was lost in
            // extraction — confirm against upstream.
            return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(ctx);
      // NOTE(review): a pattern-population line was lost in extraction here.
      populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
      (void)applyPatternsGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    // NOTE(review): a line was lost in extraction here — possibly
    // `populateCommonGPUTypeAndAttributeConversions(converter);` per the
    // surrounding references. Confirm against upstream.

    RewritePatternSet llvmPatterns(ctx);
    // NOTE(review): the declaration of the conversion `target` used below
    // was lost in extraction — confirm against upstream.

    // Collect conversion patterns from every loaded dialect that implements
    // ConvertToLLVMPatternInterface; an empty `allowedDialects` option means
    // "all dialects".
    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : ctx->getLoadedDialects()) {
      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if dialect was explicily specified but doesn't implement
        // conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
                                         *maybeChipset);
    // NOTE(review): a legality-configuration line (presumably
    // `configureGpuToROCDLConversionLegality(target);`) was lost in
    // extraction — confirm against upstream.
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};
733
734} // namespace
735
// Configures conversion-target legality for the GPU-to-ROCDL lowering:
// all GPU ops and func.func are illegal (must be converted), LLVM and ROCDL
// are legal targets, and a set of LLVM math ops is marked illegal so they are
// lowered to device-library equivalents.
// NOTE(review): the function signature line (a
// `configureGpuToROCDLConversionLegality(ConversionTarget &target)`
// definition) was lost in extraction — confirm against upstream.
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  // These ops are legal for f32 type.
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });
  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
751
// Registers all GPU -> ROCDL conversion patterns: index-op lowerings,
// OCKL-call dim lowerings, func/return/printf lowerings, and the shuffle /
// lane-id / subgroup / barrier patterns defined above.
// NOTE(review): the signature's opening line(s) and likely some `using`
// declarations (IndexKind, IntrType, Runtime are used unqualified below)
// were lost in extraction — confirm against upstream.
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  // Patterns generated from GPUToROCDL.cpp.inc (included above).
  populateWithGenerated(patterns);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  // NOTE(review): the opening `patterns.add<...OpLowering<` line for the
  // BlockIdOp lowering was lost in extraction — confirm against upstream.
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  // Block/grid dimensions go through the OCKL-call pattern defined above.
  patterns.add<GPUDimOpToOcklCall<gpu::BlockDimOp>>(converter,
                                                    IndexKind::Block);
  patterns.add<GPUDimOpToOcklCall<gpu::GridDimOp>>(converter, IndexKind::Grid);
  patterns.add<GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      // NOTE(review): the line opening this options braced initializer was
      // lost in extraction — confirm against upstream.
      /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
      /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
      rocdlDialect->getKernelAttrHelper().getName(),
      rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName(),
      /*kernelClusterSizeAttributeName=*/{}});
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf()
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
               GPUSubgroupBroadcastOpToROCDL, GPUBallotOpToROCDL>(converter);
  patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
               GPUBarrierOpLowering>(converter, chipset);

  populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
return success()
b getContext())
static Value getLaneId(RewriterBase &rewriter, Location loc)
static constexpr int64_t kMaxThreadsPerBlockDim
Maximum number of threads per block dimension on AMD GPUs.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func)
Returns true if the given gpu.func can be safely called using the bare pointer calling convention.
static constexpr StringLiteral amdgcnDataLayout
static Value getKnownOrOcklDim(RewriterBase &rewriter, gpu::index_lowering::IndexKind indexKind, gpu::Dimension dim, Operation *contextOp, std::optional< uint32_t > opUpperBound)
Emits a call to an OCKL block/grid size function corresponding to indexKind with argument dim,...
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, Location loc, Value value, const LLVMTypeConverter &converter)
static llvm::ManagedStatic< PassManagerOptions > options
#define add(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
LogicalResult decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType, SmallVectorImpl< Value > &result, bool permitVariablySizedScalars=false)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Definition Pattern.cpp:495
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
Definition Pattern.cpp:594
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Runtime
Potential runtimes for AMD GPU kernels.
Definition Runtimes.h:15
LLVM::ConstantRangeAttr getIndexOpRange(Operation *op, gpu::Dimension dim, std::optional< uint32_t > opUpperBound, IndexKind indexKind, IntrType intrType, unsigned bitWidth)
Returns a ConstantRangeAttr for a GPU index op, or nullptr if no bounds are found.
std::optional< uint32_t > getKnownDimensionSizeAround(Operation *op, DimensionKind kind, Dimension dim)
Retrieve the constant bounds for a given dimension and dimension kind from the context surrounding op...
Include the generated interface declarations.
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, gpu::amd::Runtime runtime, amdgpu::Chipset chipset)
Collect a set of patterns to convert from the GPU dialect to ROCDL.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, std::optional< amdgpu::Chipset > chipset)
Populate the given list with patterns that convert from Math to ROCDL calls.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
Definition Passes.h:91
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void configureGpuToROCDLConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to ROCDL.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns, std::optional< amdgpu::Chipset > maybeChipset)
Tries to promote gpu.shuffles to specialized AMDGPU intrinsics.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
The lowering of gpu.printf to a call to HIP hostcalls.
The lowering of gpu.printf to a call to an external printf() function.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14