//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDL IR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Pass/Pass.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
                                  Location loc, Value value,
                                  const LLVMTypeConverter &converter) {
  int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
  int64_t indexBitwidth = converter.getIndexTypeBitwidth();
  auto indexBitwidthType =
      IntegerType::get(rewriter.getContext(), indexBitwidth);
  // TODO: use <=> in C++20.
  if (indexBitwidth > intWidth) {
    return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
  }
  if (indexBitwidth < intWidth) {
    return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
  }
  return value;
}
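// For example (illustrative only): with the default 64-bit index type, an i32
// intrinsic result is sign-extended, and with a 32-bit index type an i64
// value would be truncated:
//   %idx = llvm.sext %v : i32 to i64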

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}
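// As a rough guide (illustrative; see canConvertToBarePtr for the precise
// rules): a statically shaped, identity-layout argument such as
// memref<16x16xf32> can be passed as a bare pointer, while a dynamically
// shaped one such as memref<?xf32> cannot.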

static Value getLaneId(RewriterBase &rewriter, Location loc) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
  Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
  NamedAttribute noundef = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
  NamedAttribute lowRange = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getRangeAttrName(),
      LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
                                   APInt(32, 32)));
  NamedAttribute highRange = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getRangeAttrName(),
      LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
                                   APInt(32, 64)));
  Value mbcntLo = ROCDL::MbcntLoOp::create(
      rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
      /*res_attrs=*/
      rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
  Value laneId = ROCDL::MbcntHiOp::create(
      rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
      /*res_attrs=*/
      rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
  return laneId;
}
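// How this works (informal): mbcnt.lo(~0, 0) counts the lanes below the
// current one within the low 32 lanes of the wave, and mbcnt.hi(~0, lo) adds
// the count from the high 32 lanes, so the combined result is the current
// lane's index in [0, wavefront_size).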

static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
    "32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";
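// A few entries decoded (informal): "e" selects little-endian, "p5:32:32"
// makes private (scratch) pointers 32-bit, "A5" puts allocas in address
// space 5, "G1" makes address space 1 the default for globals, and "ni:7:8:9"
// marks the buffer address spaces as non-integral.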

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = rewriter.getContext();
    // Convert to:
    //   %mlo = call noundef range(i32 0, 32)
    //       @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by:
    //   %lid = call noundef range(i32 0, 64)
    //       @llvm.amdgcn.mbcnt.hi(-1, %mlo)

    Value laneId = getLaneId(rewriter, loc);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};

struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>::ConvertOpToLLVMPattern;

  GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
                           amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>(converter),
        chipset(chipset) {}

  LogicalResult
  matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    LLVM::ConstantRangeAttr bounds = nullptr;
    bool isBeforeGfx10 = chipset.majorVersion < 10;
    if (auto upperBoundAttr = op.getUpperBoundAttr()) {
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
          /*upper=*/op.getUpperBoundAttr().getInt() + 1);
    }
    Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
        rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
    wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
                                       *getTypeConverter());
    rewriter.replaceOp(op, {wavefrontOp});
    return success();
  }

  const amdgpu::Chipset chipset;
};
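// Example lowering (illustrative): on a wave64 target, `gpu.subgroup_size`
// becomes a `rocdl.wavefrontsize` call returning i32 (annotated with a range
// when an upper bound is known), then truncated/extended to the index type.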

struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupIdOp>::ConvertOpToLLVMPattern;

  GPUSubgroupIdOpToROCDL(const LLVMTypeConverter &converter,
                         amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::SubgroupIdOp>(converter),
        chipset(chipset) {}

  LogicalResult
  matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto int32Type = rewriter.getI32Type();

    Value subgroupId;
    if (chipset.majorVersion >= 12) {
      // For gfx12+, use the hardware wave.id register directly.
      LLVM::ConstantRangeAttr bounds;
      if (auto upperBoundAttr = op.getUpperBoundAttr())
        bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
            /*bitWidth=*/32, /*lower=*/0,
            /*upper=*/upperBoundAttr.getInt());
      subgroupId = ROCDL::WaveId::create(rewriter, loc, int32Type, bounds);
    } else {
      // For older architectures, compute:
      //   subgroup_id = linearized_thread_id / subgroup_size
      // where linearized_thread_id = tid.x + dim.x * (tid.y + dim.y * tid.z).
      Value tidX = ROCDL::ThreadIdXOp::create(rewriter, loc, int32Type);
      Value tidY = ROCDL::ThreadIdYOp::create(rewriter, loc, int32Type);
      Value tidZ = ROCDL::ThreadIdZOp::create(rewriter, loc, int32Type);
      Value dimX = ROCDL::BlockDimXOp::create(rewriter, loc, int32Type);
      Value dimY = ROCDL::BlockDimYOp::create(rewriter, loc, int32Type);

      // linearized = tid.x + dim.x * (tid.y + dim.y * tid.z)
      // Thread IDs and dimensions are non-negative and small, so use nuw+nsw.
      auto flags =
          LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
      Value dimYxTidZ =
          LLVM::MulOp::create(rewriter, loc, int32Type, dimY, tidZ, flags);
      Value tidYPlusDimYxTidZ =
          LLVM::AddOp::create(rewriter, loc, int32Type, tidY, dimYxTidZ, flags);
      Value dimXxInner = LLVM::MulOp::create(rewriter, loc, int32Type, dimX,
                                             tidYPlusDimYxTidZ, flags);
      Value linearized = LLVM::AddOp::create(rewriter, loc, int32Type, tidX,
                                             dimXxInner, flags);

      Value subgroupSize =
          ROCDL::WavefrontSizeOp::create(rewriter, loc, int32Type);
      subgroupId = LLVM::UDivOp::create(rewriter, loc, int32Type, linearized,
                                        subgroupSize);
    }

    subgroupId =
        truncOrExtToLLVMType(rewriter, loc, subgroupId, *getTypeConverter());
    rewriter.replaceOp(op, subgroupId);
    return success();
  }

  const amdgpu::Chipset chipset;
};
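// Worked example (illustrative): with block dims (128, 2, 1), thread
// (5, 1, 0), and wave64: linearized = 5 + 128 * (1 + 2 * 0) = 133, so
// subgroup_id = 133 / 64 = 2.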

static bool isSupportedReadLaneType(Type type) {
  // https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics
  if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
          LLVM::LLVMPointerType>(type))
    return true;

  if (auto intType = dyn_cast<IntegerType>(type))
    return llvm::is_contained({16, 32, 64},
                              static_cast<int>(intType.getWidth()));

  if (auto vecType = dyn_cast<VectorType>(type)) {
    Type elementType = vecType.getElementType();
    if (elementType.isInteger(32))
      return true;

    if (vecType.getNumElements() == 2 &&
        (isa<Float16Type, BFloat16Type>(elementType) ||
         elementType.isInteger(16)))
      return true;
  }

  return false;
}

struct GPUSubgroupBroadcastOpToROCDL
    : public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSrc();
    if (isSupportedReadLaneType(src.getType())) {
      Value result = createReadlaneOp(op, adaptor, rewriter, src);
      rewriter.replaceOp(op, result);
      return success();
    }

    Type i32 = rewriter.getI32Type();
    Location loc = op.getLoc();
    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed,
                                    /*permitVariablySizedScalars=*/true)))
      return rewriter.notifyMatchFailure(op,
                                         "unexpected decomposition failure");

    SmallVector<Value> results;
    results.reserve(decomposed.size());
    for (Value v : decomposed)
      results.emplace_back(createReadlaneOp(op, adaptor, rewriter, v));

    Value result = LLVM::composeValue(rewriter, loc, results, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Value createReadlaneOp(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Value src) {
    if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
      return ROCDL::ReadlaneOp::create(rewriter, op.getLoc(), src.getType(),
                                       src, adaptor.getLane());
    }
    // first_active_lane
    return ROCDL::ReadfirstlaneOp::create(rewriter, op.getLoc(), src.getType(),
                                          src);
  }
};
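// Example (illustrative): broadcasting a vector<4xf16> is not directly
// supported by readlane, so it is bitcast-decomposed into two i32 values,
// each broadcast with rocdl.readlane, and the results recomposed into a
// vector<4xf16>.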

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// The `width` argument is used to check whether the source lane
  /// participates; if it does not, the destination lane falls back to the
  /// current lane itself.
  ///
  /// Shuffle with DS Bpermute:
  ///   let shflMode = [xor, up, down, idx]
  ///   let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width]
  ///   1. curLaneId = using mbcnt.lo + mbcnt.hi
  ///   2. widthOrZeroIfOutside = (curLaneId + width) & -width
  ///   3. dstLane = shflMode(curLaneId, step)
  ///   4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  ///   5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  ///   6. dwordAlignedDstLane = dstLane * 4, i.e. dstLane << 2
  ///   7. bpermute(dwordAlignedDstLane, shfl_value)
  ///
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value initShflValue = adaptor.getValue();

    Value srcLaneId = getLaneId(rewriter, loc);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
    Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
    Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
    Value dstLane;

    switch (op.getMode()) {
    case gpu::ShuffleMode::UP:
      dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::DOWN:
      dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    }
    Value isActiveSrcLane = LLVM::ICmpOp::create(
        rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
                                                 dstLane, srcLaneId);
    Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
    Value dwordAlignedDstLane =
        LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);

    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type,
                                    decomposed)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
                                              dwordAlignedDstLane, v);
      swizzled.emplace_back(res);
    }
    Value shflValue =
        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};
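// Worked example (illustrative): an xor shuffle with offset 1 and width 64
// for lane 5 computes dstLane = 5 ^ 1 = 4 and widthOrZeroIfOutside =
// (5 + 64) & -64 = 64, so the lane is active (4 < 64) and reads lane 4 via
// ds_bpermute at byte address 4 << 2 = 16.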

struct GPUBarrierOpLowering final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  GPUBarrierOpLowering(const LLVMTypeConverter &converter,
                       amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::BarrierOp>(converter), chipset(chipset) {}

  amdgpu::Chipset chipset;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Analyze the address_spaces attribute to determine fence behavior.
    bool fenceGlobal = false;
    bool fenceLDS = false;
    std::optional<ArrayAttr> addrSpacesToFence = op.getAddressSpaces();

    if (addrSpacesToFence) {
      for (auto spaceAttr :
           addrSpacesToFence->getAsRange<gpu::AddressSpaceAttr>()) {
        switch (spaceAttr.getValue()) {
        case gpu::AddressSpace::Global:
          fenceGlobal = true;
          break;
        case gpu::AddressSpace::Workgroup:
          fenceLDS = true;
          break;
        case gpu::AddressSpace::Private:
          break;
        }
      }
    } else {
      // Default semantics match __syncthreads() and fence both global and LDS.
      fenceGlobal = true;
      fenceLDS = true;
    }

    Attribute mmra;
    if (fenceLDS && !fenceGlobal) {
      mmra =
          rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
    } else if (fenceGlobal && !fenceLDS) {
      mmra = rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as",
                                                 "global");
    }

    constexpr llvm::StringLiteral scope = "workgroup";

    bool emitFences = fenceGlobal || fenceLDS;
    // Emit release fence if needed.
    if (emitFences) {
      auto relFence = LLVM::FenceOp::create(
          rewriter, loc, LLVM::AtomicOrdering::release, scope);
      if (mmra)
        relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
                                     mmra);
    }

    if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    if (emitFences) {
      auto acqFence = LLVM::FenceOp::create(
          rewriter, loc, LLVM::AtomicOrdering::acquire, scope);
      if (mmra)
        acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
                                     mmra);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
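// Example (illustrative): a gpu.barrier fencing only the workgroup address
// space becomes, roughly, an LLVM release fence at workgroup scope tagged with
// the "amdgpu-synchronize-as"="local" MMRA, the barrier intrinsic(s), and a
// matching acquire fence, so only LDS traffic is ordered.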

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
struct LowerGpuOpsToROCDLOpsPass final
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
      (void)applyPatternsGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    populateCommonGPUTypeAndAttributeConversions(converter);

    RewritePatternSet llvmPatterns(ctx);
    LLVMConversionTarget target(getContext());

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : ctx->getLoadedDialects()) {
      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but doesn't
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
                                         *maybeChipset);
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  // These ops are legal for f32 type.
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });
  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::populateGpuToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  using mlir::gpu::amd::Runtime;
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  populateWithGenerated(patterns);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                      ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          rocdlDialect->getKernelAttrHelper().getName(),
          rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName(),
          /*kernelClusterSizeAttributeName=*/{}});
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
               GPUSubgroupBroadcastOpToROCDL>(converter);
  patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
               GPUBarrierOpLowering>(converter, chipset);

  populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
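// Typical driver usage (illustrative): these patterns are normally reached
// via the registered pass, e.g.
//   mlir-opt --convert-gpu-to-rocdl='chipset=gfx942 runtime=HIP' kernel.mlir
// where the option names follow the ConvertGpuOpsToROCDLOps pass definition.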