//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
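// For example, an i32 value is sign-extended to i64 when the index bitwidth
// is 64, truncated to i16 when it is 16, and returned unchanged when it is 32.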
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
                                  Location loc, Value value,
                                  const LLVMTypeConverter &converter) {
  int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
  int64_t indexBitwidth = converter.getIndexTypeBitwidth();
  auto indexBitwidthType =
      IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
  // TODO: use <=> in C++20.
  if (indexBitwidth > intWidth) {
    return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
  }
  if (indexBitwidth < intWidth) {
    return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
  }
  return value;
}

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}

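/// Returns the lane id of the current thread as an i32, computed with the
/// llvm.amdgcn.mbcnt.lo/hi intrinsics over an all-ones mask. The results are
/// annotated noundef with the ranges [0, 32) and [0, 64) so later passes can
/// reason about wave32/wave64 lane counts.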
static Value getLaneId(RewriterBase &rewriter, Location loc) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
  Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
  NamedAttribute noundef = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
  NamedAttribute lowRange = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getRangeAttrName(),
      LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
                                   APInt(32, 32)));
  NamedAttribute highRange = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getRangeAttrName(),
      LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
                                   APInt(32, 64)));
  Value mbcntLo = ROCDL::MbcntLoOp::create(
      rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
      /*res_attrs=*/
      rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
  Value laneId = ROCDL::MbcntHiOp::create(
      rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
      rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
  return laneId;
}

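// Default data layout string for AMDGCN targets, mirroring the layout used by
// the LLVM AMDGPU backend; applied when the module does not already carry an
// explicit llvm.data_layout attribute.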
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
    "32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = rewriter.getContext();
    // Convert to:
    //   %mlo = call noundef range(i32 0, 32)
    //          @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by:
    //   %lid = call noundef range(i32 0, 64)
    //          @llvm.amdgcn.mbcnt.hi(-1, %mlo)
    Value laneId = getLaneId(rewriter, loc);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};

struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>::ConvertOpToLLVMPattern;

  GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
                           amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>(converter),
        chipset(chipset) {}

  LogicalResult
  matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    LLVM::ConstantRangeAttr bounds = nullptr;
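    // Pre-gfx10 chips only support wave64; gfx10+ supports wave32 and wave64,
    // so the range's lower bound depends on the chipset generation.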
    bool isBeforeGfx10 = chipset.majorVersion < 10;
    if (auto upperBoundAttr = op.getUpperBoundAttr()) {
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
          /*upper=*/op.getUpperBoundAttr().getInt() + 1);
    }
    Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
        rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
    wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
                                       *getTypeConverter());
    rewriter.replaceOp(op, {wavefrontOp});
    return success();
  }

  const amdgpu::Chipset chipset;
};

struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupIdOp>::ConvertOpToLLVMPattern;

  GPUSubgroupIdOpToROCDL(const LLVMTypeConverter &converter,
                         amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::SubgroupIdOp>(converter), chipset(chipset) {
  }

  LogicalResult
  matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto int32Type = rewriter.getI32Type();

    Value subgroupId;
    if (chipset.majorVersion >= 12) {
      // For gfx12+, use the hardware wave.id register directly.
      LLVM::ConstantRangeAttr bounds;
      if (auto upperBoundAttr = op.getUpperBoundAttr())
        bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
            /*bitWidth=*/32, /*lower=*/0,
            /*upper=*/upperBoundAttr.getInt());
      subgroupId = ROCDL::WaveId::create(rewriter, loc, int32Type, bounds);
    } else {
      // For older architectures, compute:
      //   subgroup_id = linearized_thread_id / subgroup_size
      // where linearized_thread_id = tid.x + dim.x * (tid.y + dim.y * tid.z).
      Value tidX = ROCDL::ThreadIdXOp::create(rewriter, loc, int32Type);
      Value tidY = ROCDL::ThreadIdYOp::create(rewriter, loc, int32Type);
      Value tidZ = ROCDL::ThreadIdZOp::create(rewriter, loc, int32Type);
      Value dimX = ROCDL::BlockDimXOp::create(rewriter, loc, int32Type);
      Value dimY = ROCDL::BlockDimYOp::create(rewriter, loc, int32Type);

      // linearized = tid.x + dim.x * (tid.y + dim.y * tid.z)
      // Thread IDs and dimensions are non-negative and small, so use nuw+nsw.
      auto flags =
          LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
      Value dimYxTidZ =
          LLVM::MulOp::create(rewriter, loc, int32Type, dimY, tidZ, flags);
      Value tidYPlusDimYxTidZ =
          LLVM::AddOp::create(rewriter, loc, int32Type, tidY, dimYxTidZ, flags);
      Value dimXxInner = LLVM::MulOp::create(rewriter, loc, int32Type, dimX,
                                             tidYPlusDimYxTidZ, flags);
      Value linearized = LLVM::AddOp::create(rewriter, loc, int32Type, tidX,
                                             dimXxInner, flags);

      Value subgroupSize =
          ROCDL::WavefrontSizeOp::create(rewriter, loc, int32Type);
      subgroupId = LLVM::UDivOp::create(rewriter, loc, int32Type, linearized,
                                        subgroupSize);
    }

    subgroupId =
        truncOrExtToLLVMType(rewriter, loc, subgroupId, *getTypeConverter());
    rewriter.replaceOp(op, subgroupId);
    return success();
  }

  const amdgpu::Chipset chipset;
};

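/// Returns true if `type` can be fed directly to the ROCDL readlane/
/// readfirstlane intrinsics; other types must first be decomposed into i32
/// words.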
static bool isSupportedReadLaneType(Type type) {
  // https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics
  if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
          LLVM::LLVMPointerType>(type))
    return true;

  if (auto intType = dyn_cast<IntegerType>(type))
    return llvm::is_contained({16, 32, 64},
                              static_cast<int>(intType.getWidth()));

  if (auto vecType = dyn_cast<VectorType>(type)) {
    Type elementType = vecType.getElementType();
    if (elementType.isInteger(32))
      return true;

    if (vecType.getNumElements() == 2 &&
        (isa<Float16Type, BFloat16Type>(elementType) ||
         elementType.isInteger(16)))
      return true;
  }

  return false;
}

struct GPUSubgroupBroadcastOpToROCDL
    : public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSrc();
    if (isSupportedReadLaneType(src.getType())) {
      Value result = createReadlaneOp(op, adaptor, rewriter, src);
      rewriter.replaceOp(op, result);
      return success();
    }

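    // Other types are decomposed into i32 words, each word is broadcast
    // individually, and the results are recomposed into the original type.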
    Type i32 = rewriter.getI32Type();
    Location loc = op.getLoc();
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);

    SmallVector<Value> results;
    results.reserve(decomposed.size());
    for (Value v : decomposed)
      results.emplace_back(createReadlaneOp(op, adaptor, rewriter, v));

    Value result = LLVM::composeValue(rewriter, loc, results, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Value createReadlaneOp(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Value src) {
    if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
      return ROCDL::ReadlaneOp::create(rewriter, op.getLoc(), src.getType(),
                                       src, adaptor.getLane());
    } else { // first_active_lane
      return ROCDL::ReadfirstlaneOp::create(rewriter, op.getLoc(),
                                            src.getType(), src);
    }
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// The `width` argument determines whether the source lane participates:
  /// if it does not, the destination lane falls back to the current lane.
  ///
  /// Shuffle with DS Bpermute:
  ///   let shflMode = [xor, up, down, idx]
  ///   let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width]
  ///   1. curLaneId = mbcnt.hi(-1, mbcnt.lo(-1, 0))
  ///   2. widthOrZeroIfOutside = (curLaneId + width) & -width
  ///   3. dstLane = shflMode(curLaneId, step)
  ///   4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  ///   5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  ///   6. dwordAlignedDstLane = dstLane * 4, i.e. dstLane << 2
  ///   7. bpermute(dwordAlignedDstLane, shfl_value)
  ///
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value initShflValue = adaptor.getValue();

    Value srcLaneId = getLaneId(rewriter, loc);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
    Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
    Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
    Value dstLane;

    switch (op.getMode()) {
    case gpu::ShuffleMode::UP:
      dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::DOWN:
      dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    }
    Value isActiveSrcLane = LLVM::ICmpOp::create(
        rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
                                                 dstLane, srcLaneId);
    Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
    Value dwordAlignedDstLane =
        LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);

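    // Values wider than 32 bits are shuffled one i32 word at a time through
    // ds_bpermute and recombined afterwards.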
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
                                              dwordAlignedDstLane, v);
      swizzled.emplace_back(res);
    }
    Value shflValue =
        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};

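/// Lowers gpu.barrier to a release fence, the target-appropriate s_barrier
/// intrinsic(s), and an acquire fence. When the op names specific address
/// spaces, MMRA metadata narrows the fences to just those address spaces.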
struct GPUBarrierOpLowering final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  GPUBarrierOpLowering(const LLVMTypeConverter &converter,
                       amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::BarrierOp>(converter), chipset(chipset) {}

  amdgpu::Chipset chipset;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Analyze the address_spaces attribute to determine fence behavior.
    bool fenceGlobal = false;
    bool fenceLDS = false;
    std::optional<ArrayAttr> addrSpacesToFence = op.getAddressSpaces();

    if (addrSpacesToFence) {
      for (auto spaceAttr :
           addrSpacesToFence->getAsRange<gpu::AddressSpaceAttr>()) {
        switch (spaceAttr.getValue()) {
        case gpu::AddressSpace::Global:
          fenceGlobal = true;
          break;
        case gpu::AddressSpace::Workgroup:
          fenceLDS = true;
          break;
        case gpu::AddressSpace::Private:
          break;
        }
      }
    } else {
      // Default semantics match __syncthreads() and fence both global and LDS.
      fenceGlobal = true;
      fenceLDS = true;
    }

    Attribute mmra;
    if (fenceLDS && !fenceGlobal) {
      mmra =
          rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
    } else if (fenceGlobal && !fenceLDS) {
      mmra = rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as",
                                                 "global");
    }

    constexpr llvm::StringLiteral scope = "workgroup";

    bool emitFences = fenceGlobal || fenceLDS;
    // Emit release fence if needed.
    if (emitFences) {
      auto relFence = LLVM::FenceOp::create(
          rewriter, loc, LLVM::AtomicOrdering::release, scope);
      if (mmra)
        relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
                                     mmra);
    }

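    // gfx12 split s_barrier into separate signal and wait operations; older
    // chips expose a single s_barrier.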
    if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    if (emitFences) {
      auto acqFence = LLVM::FenceOp::create(
          rewriter, loc, LLVM::AtomicOrdering::acquire, scope);
      if (mmra)
        acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
                                     mmra);
    }

    rewriter.eraseOp(op);
    return success();
  }
};

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalents.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
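//
// The pass runs on gpu.module ops, e.g. via
//   mlir-opt --pass-pipeline='builtin.module(gpu.module(convert-gpu-to-rocdl))'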
struct LowerGpuOpsToROCDLOpsPass final
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      populateGpuPromoteShuffleToAMDGPUPatterns(patterns, *maybeChipset);
      (void)applyPatternsGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    populateCommonGPUTypeAndAttributeConversions(converter);

    RewritePatternSet llvmPatterns(ctx);
    LLVMConversionTarget target(getContext());

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : ctx->getLoadedDialects()) {
      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but does not
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
                                         *maybeChipset);
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
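  // The AMDGPU backend cannot lower these LLVM math intrinsics for all types,
  // so mark them illegal to force expansion into device library calls.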
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  // These ops remain legal for the f32 type.
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });
  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::populateGpuToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  using mlir::gpu::amd::Runtime;
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  populateWithGenerated(patterns);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                      ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          rocdlDialect->getKernelAttrHelper().getName(),
          rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName(),
          /*kernelClusterSizeAttributeName=*/{}});
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
               GPUSubgroupBroadcastOpToROCDL>(converter);
  patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
               GPUBarrierOpLowering>(converter, chipset);

  populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}