AMDGPUDialect.cpp
1//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/Diagnostics.h"
25#include "mlir/IR/Matchers.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34#include <algorithm>
35#include <cstdint>
36#include <limits>
37#include <optional>
38
39using namespace mlir;
40using namespace mlir::amdgpu;
41
42#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
43
44namespace {
45struct AMDGPUInlinerInterface final : DialectInlinerInterface {
46 using DialectInlinerInterface::DialectInlinerInterface;
47 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
48 return true;
49 }
50};
51} // namespace
52
53void AMDGPUDialect::initialize() {
54 addOperations<
55#define GET_OP_LIST
56#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
57 >();
58 addTypes<
59#define GET_TYPEDEF_LIST
60#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
61 >();
62 addAttributes<
63#define GET_ATTRDEF_LIST
64#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
65 >();
66 addInterfaces<AMDGPUInlinerInterface>();
67}
68
69//===----------------------------------------------------------------------===//
70// 8-bit float ops
71//===----------------------------------------------------------------------===//
72LogicalResult PackedTrunc2xFp8Op::verify() {
73 if (getExisting() && getExisting().getType() != getResult().getType())
74 return emitOpError("existing values must have same type as result");
75 return success();
76}
77
78LogicalResult PackedStochRoundFp8Op::verify() {
79 if (getExisting() && getExisting().getType() != getResult().getType())
80 return emitOpError("existing values must have same type as result");
81 return success();
82}
83
84//===----------------------------------------------------------------------===//
85// mxfp float ops
86//===----------------------------------------------------------------------===//
87LogicalResult PackedScaledTruncOp::verify() {
88 if (getExisting() && getExisting().getType() != getResult().getType())
89 return emitOpError("existing values must have same type as result");
90 return success();
91}
92
93//===----------------------------------------------------------------------===//
94// FatRawBufferCastOp
95//===----------------------------------------------------------------------===//
96
97/// Convert the type `source` to one with the same sizes and strides - and
98/// offset, unless `resetOffset` is true, in which case the offset is reset to
99/// 0. If the offset should be reset but the layout of `source` is neither the
100/// identity layout nor a strided layout, this function fails.
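/// For example (illustrative, assuming a strided source type): with
/// `resetOffset` true,
///   memref<8x?xf32, strided<[?, 1], offset: ?>>
/// would map to
///   memref<8x?xf32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>.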
101static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
102 bool resetOffset) {
103 MLIRContext *ctx = source.getContext();
104 MemRefType::Builder mb(source);
105 mb.setMemorySpace(
106 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
107 MemRefLayoutAttrInterface layout = source.getLayout();
108 if (resetOffset && !layout.isIdentity()) {
109 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
110 if (!stridedLayout)
111 return failure();
112 MemRefLayoutAttrInterface newLayout =
113 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
114 // Special case: if resetting the offset causes the strided layout to become
115 // the identity layout, then reset to the identity layout.
116 // TODO: this'll get a lot simpler when we have the contiguous layout.
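    // For example (illustrative): strided<[8, 1]> with offset 0 over a static
    // 8x8 shape equals the suffix product [8, 1] and so folds to the identity.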
117 SmallVector<int64_t> stridesIfIdentity;
118 if (source.hasStaticShape()) {
119 stridesIfIdentity = computeSuffixProduct(source.getShape());
120 } else if (source.getRank() <= 1) {
121 stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
122 }
123 if (stridesIfIdentity == stridedLayout.getStrides()) {
124 newLayout = AffineMapAttr::get(
125 AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
126 }
127 mb.setLayout(newLayout);
128 }
129 return (MemRefType)(mb);
130}
131
132LogicalResult FatRawBufferCastOp::inferReturnTypes(
133 MLIRContext *context, std::optional<Location> location, ValueRange operands,
134 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
135 SmallVectorImpl<Type> &inferredReturnTypes) {
136 Adaptor adaptor(operands, attributes, properties, regions);
137 auto sourceType =
138 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
139 if (!sourceType)
140 return failure();
141 FailureOr<MemRefType> resultType =
142 getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
143 if (failed(resultType))
144 return failure();
145 inferredReturnTypes = SmallVector<Type>{*resultType};
146 return success();
147}
148
149LogicalResult FatRawBufferCastOp::verify() {
150 FailureOr<MemRefType> expectedResultType =
151 getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
152 if (failed(expectedResultType))
153 return emitOpError("source type ")
154 << getSource().getType() << " can't have its offset reset";
155 if (getResult().getType() != *expectedResultType)
156 return emitOpError("expected result type to be ")
157 << *expectedResultType << " but got " << getResult().getType();
158 return success();
159}
160
161static bool hasGlobalMemorySpace(Attribute memorySpace) {
162 if (!memorySpace)
163 return true;
164 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
165 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
166 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
167 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
168 return false;
169}
170
171static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
172 if (!memorySpace)
173 return false;
174 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
175 return intMemorySpace.getInt() == 3;
176 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
177 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
178 return false;
179}
180
181static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
182 if (!memorySpace)
183 return false;
184 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
185 return intMemorySpace.getInt() == 7;
186 if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
187 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
188 return false;
189}
190
191//===----------------------------------------------------------------------===//
192// RawBuffer*Op
193//===----------------------------------------------------------------------===//
194template <typename T>
195static LogicalResult verifyRawBufferOp(T &op) {
196 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
197 bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
198
199 if (!isGlobal)
200 return op.emitOpError(
201 "Buffer ops must operate on a memref in global memory");
202 if (!bufferType.hasRank())
203 return op.emitOpError(
204 "Cannot meaningfully buffer_store to an unranked memref");
205 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
206 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
207 " indices to memref");
208 return success();
209}
210
211LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
212
213LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
214
215LogicalResult RawBufferAtomicFaddOp::verify() {
216 return verifyRawBufferOp(*this);
217}
218
219LogicalResult RawBufferAtomicFmaxOp::verify() {
220 return verifyRawBufferOp(*this);
221}
222
223LogicalResult RawBufferAtomicSmaxOp::verify() {
224 return verifyRawBufferOp(*this);
225}
226
227LogicalResult RawBufferAtomicUminOp::verify() {
228 return verifyRawBufferOp(*this);
229}
230
231LogicalResult RawBufferAtomicCmpswapOp::verify() {
232 return verifyRawBufferOp(*this);
233}
234
235static std::optional<uint32_t> getConstantUint32(Value v) {
236 APInt cst;
237 if (!v.getType().isInteger(32))
238 return std::nullopt;
239 if (matchPattern(v, m_ConstantInt(&cst)))
240 return cst.getZExtValue();
241 return std::nullopt;
242}
243
244template <typename OpType>
245static bool staticallyOutOfBounds(OpType op) {
246 if (!op.getBoundsCheck())
247 return false;
248 MemRefType bufferType = op.getMemref().getType();
249 if (!bufferType.hasStaticShape())
250 return false;
251 int64_t offset;
252 SmallVector<int64_t> strides;
253 if (failed(bufferType.getStridesAndOffset(strides, offset)))
254 return false;
255 int64_t result = offset + op.getIndexOffset().value_or(0);
256 if (op.getSgprOffset()) {
257 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
258 if (!sgprOffset)
259 return false;
260 result += *sgprOffset;
261 }
262 if (strides.size() != op.getIndices().size())
263 return false;
264 int64_t indexVal = 0;
265 for (auto pair : llvm::zip(strides, op.getIndices())) {
266 int64_t stride = std::get<0>(pair);
267 Value idx = std::get<1>(pair);
268 std::optional<uint32_t> idxVal = getConstantUint32(idx);
269 if (!idxVal)
270 return false;
271 indexVal += stride * *idxVal;
272 }
273 result += indexVal;
274 if (result > std::numeric_limits<uint32_t>::max())
275 // Overflow means don't drop
276 return false;
277 return result >= bufferType.getNumElements();
278}
279
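// Illustrative example (hypothetical IR): with boundsCheck enabled, a raw
// buffer load from a memref<16xf32> at constant index 32 is statically out of
// bounds, so the patterns below fold it to a zero constant (for loads and
// cmpswap) or erase it (for stores and the other atomics).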
280namespace {
281template <typename OpType>
282struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
283 using OpRewritePattern<OpType>::OpRewritePattern;
284
285 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
286 if (!staticallyOutOfBounds(op))
287 return failure();
288 Type loadType = op.getResult().getType();
289 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
290 rw.getZeroAttr(loadType));
291 return success();
292 }
293};
294
295template <typename OpType>
296struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
297 using OpRewritePattern<OpType>::OpRewritePattern;
298
299 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
300 if (!staticallyOutOfBounds(op))
301 return failure();
302
303 rw.eraseOp(op);
304 return success();
305 }
306};
307} // end namespace
308
309void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
310 MLIRContext *context) {
311 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
312}
313
314void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
315 MLIRContext *context) {
316 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
317}
318
319void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
320 RewritePatternSet &results, MLIRContext *context) {
321 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
322}
323
324void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
325 RewritePatternSet &results, MLIRContext *context) {
326 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
327}
328
329void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
330 RewritePatternSet &results, MLIRContext *context) {
331 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
332}
333
334void RawBufferAtomicUminOp::getCanonicalizationPatterns(
335 RewritePatternSet &results, MLIRContext *context) {
336 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
337}
338
339void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
340 RewritePatternSet &results, MLIRContext *context) {
341 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
342 context);
343}
344
345//===----------------------------------------------------------------------===//
346// ScaledExtPackedMatrixOp
347//===----------------------------------------------------------------------===//
348LogicalResult ScaledExtPackedMatrixOp::verify() {
349 int blockSize = getBlockSize();
350 assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
351
352 int firstScaleByte = getFirstScaleByte();
353 int firstScaleLane = getFirstScaleLane();
354 auto sourceType = cast<VectorType>(getSource().getType());
355 Type elementType = sourceType.getElementType();
356 auto floatType = cast<FloatType>(elementType);
357 unsigned bitWidth = floatType.getWidth();
358
359 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
360
361 const bool is_fp8 = bitWidth == 8;
362 const bool is_block_16 = blockSize == 16;
363
364 if (!is_fp8) {
365 if (is_block_16) {
366 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
367 return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
368 "or 1 for f4 and f6.");
369 }
370 } else {
371 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
372 return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
373 "or 2 for f4 and f6.");
374 }
375 }
376 } else {
377 if (is_block_16) {
378 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
379 ((firstScaleLane == 16) && (firstScaleByte == 2));
380 if (!is_valid) {
381 return emitOpError("blockSize of 16 can only have (firstScaleLane, "
382 "firstScaleByte) be (0, 0) or (16, 2) for f8.");
383 }
384 }
385 }
386
387 return success();
388}
389
390//===----------------------------------------------------------------------===//
391// WMMAOp
392//===----------------------------------------------------------------------===//
393
393
394ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
395 IntegerAttr &m, IntegerAttr &n,
396 IntegerAttr &k) {
397 SmallVector<int64_t, 3> dimensions;
398 if (parser.parseDimensionList(dimensions, false, false))
399 return failure();
400 if (dimensions.size() != 3)
401 return parser.emitError(parser.getCurrentLocation())
402 << "expected 3 dimensions in MNK dimension list";
403
404 m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
405 n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
406 k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
407 return success();
408}
409
410LogicalResult WMMAOp::verify() {
411 auto sourceAType = cast<VectorType>(getSourceA().getType());
412 auto sourceBType = cast<VectorType>(getSourceB().getType());
413 auto destType = cast<VectorType>(getDestC().getType());
414
415 Type sourceAElemType = sourceAType.getElementType();
416 Type sourceBElemType = sourceBType.getElementType();
417 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
418 return emitOpError("source vectors have different lengths: ")
419 << sourceAType << " vs. " << sourceBType;
420 }
421
422 bool isDestFloat = destType.getElementType().isFloat();
423 bool isSrcFloat = sourceAElemType.isFloat();
424
425 if (isDestFloat && !isSrcFloat)
426 return emitOpError("expected float sources with float destination");
427 if (!isDestFloat && isSrcFloat)
428 return emitOpError("expected int sources with int destination");
429
430 if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
431 return emitOpError(
432 "source element types must match (except for fp8/bf8) but have ")
433 << sourceAType << " and " << sourceBType;
434 }
435
436 if (isSrcFloat) {
437 if (getClamp())
438 return emitOpError("clamp flag is not supported for float types");
439 if (getUnsignedA() || getUnsignedB())
440 return emitOpError("unsigned flags are not supported for float types");
441 }
442 return success();
443}
444
445//===----------------------------------------------------------------------===//
446// MFMAOp
447//===----------------------------------------------------------------------===//
448LogicalResult MFMAOp::verify() {
449 constexpr uint32_t waveSize = 64;
450 Builder b(getContext());
451
452 Type sourceType = getSourceA().getType();
453 Type destType = getDestC().getType();
454
455 Type sourceElem = sourceType, destElem = destType;
456 uint32_t sourceLen = 1, destLen = 1;
457 if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
458 sourceLen = sourceVector.getNumElements();
459 sourceElem = sourceVector.getElementType();
460 }
461 if (auto destVector = dyn_cast<VectorType>(destType)) {
462 destLen = destVector.getNumElements();
463 destElem = destVector.getElementType();
464 }
465
466 Type sourceBType = getSourceB().getType();
467 if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
468 int64_t sourceBLen = 1;
469 Type sourceBElem = sourceBType;
470 if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
471 sourceBLen = sourceBVector.getNumElements();
472 sourceBElem = sourceBVector.getElementType();
473 }
474 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
475 !sourceBElem.isFloat(4))
476 return emitOpError("expected both source operands to have small-float "
477 "elements if one does");
478 if (sourceLen != sourceBLen)
479 return emitOpError(
480 "expected both small-float source vectors to have the same length");
481 } else {
482 if (sourceType != sourceBType)
483 return emitOpError("expected both non-small-float source operand types "
484 "to match exactly");
485 }
486 // Normalize the wider integer types the compiler expects to i8.
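  // For example (illustrative): a vector<4xi32> source operand is counted as
  // 16 packed i8 elements.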
487 if (sourceElem.isInteger(32)) {
488 sourceLen *= 4;
489 sourceElem = b.getI8Type();
490 }
491 if (sourceElem.isInteger(64)) {
492 sourceLen *= 8;
493 sourceElem = b.getI8Type();
494 }
495
496 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
497 if (sourceLen != numSourceElems)
498 return emitOpError("expected " + Twine(numSourceElems) +
499 " source values for this operation but got " +
500 Twine(sourceLen));
501
502 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
503 if (destLen != numDestElems)
504 return emitOpError("expected " + Twine(numDestElems) +
505 " result values for this operation but got " +
506 Twine(destLen));
507
508 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
509 return emitOpError(
510 "double-precision ops do not support permuting lanes of B");
511 if (destElem.isF64() && getCbsz() != 0)
512 return emitOpError(
513 "double-precision ops do not support permuting lanes of A");
514 if (getAbid() >= (1u << getCbsz()))
515 return emitOpError(
516 "block ID for permuting A (abid) must be below 2 ** cbsz");
517
518 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
519 return emitOpError(
520 "negation flags only available for double-precision operations");
521
522 return success();
523}
524
525//===----------------------------------------------------------------------===//
526// DPPOp
527//===----------------------------------------------------------------------===//
528LogicalResult DPPOp::verify() {
529 Type srcType = getSrc().getType();
530 if (srcType.getIntOrFloatBitWidth() > 64) {
531 return emitOpError("integer and floating point types larger than 64 bits "
532 "are not supported");
533 }
534
535 DPPPerm kind = getKind();
536 Attribute permArgument = getPermArgument().value_or(Attribute{});
537
538 switch (kind) {
539
540 case DPPPerm::quad_perm: {
541 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
542 if (!quadPermAttr || quadPermAttr.size() != 4) {
543 return emitOpError("quad_perm attribute must have exactly 4 elements");
544 }
545 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
546 int32_t num = elem.getInt();
547 if (num < 0 || num > 3) {
548 return emitOpError(
549 "Each element of quad_perm must be in the range [0, 3]");
550 }
551 }
552 } break;
553
554 case DPPPerm::row_shl:
555 case DPPPerm::row_shr:
556 case DPPPerm::row_ror: {
557 if (!permArgument) {
558 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
559 "' value not specified");
560 }
561 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
562 uint32_t attrValue = intAttr.getInt();
563 if (attrValue < 1 || attrValue > 15) {
564 return emitOpError("Attribute value must be between 1 and 15");
565 }
566 }
567 } break;
568
569 case DPPPerm::wave_shl:
570 case DPPPerm::wave_shr:
571 case DPPPerm::wave_rol:
572 case DPPPerm::wave_ror:
573 case DPPPerm::row_mirror:
574 case DPPPerm::row_half_mirror:
575 case DPPPerm::row_bcast_15:
576 case DPPPerm::row_bcast_31: {
577 if (permArgument && !isa<UnitAttr>(permArgument)) {
578 return emitOpError("Expected unit attribute for permArgument, but found "
579 "non-trivial argument");
580 }
581 break;
582 }
583 }
584 return success();
585}
586
587//===----------------------------------------------------------------------===//
588// PermlaneSwapOp
589//===----------------------------------------------------------------------===//
590LogicalResult PermlaneSwapOp::verify() {
591 unsigned rowLength = getRowLength();
592
593 if (rowLength != 16 && rowLength != 32)
594 return emitOpError("row_length attribute must either be 16 or 32.");
595
596 return success();
597}
598
599//===----------------------------------------------------------------------===//
600// MemoryCounterWaitOp
601//===----------------------------------------------------------------------===//
602
603namespace {
604/// Fuse adjacent memory counter wait ops, taking the minimum value of the
605/// counters.
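/// For example (illustrative counter values): a wait with {load = 2, ds = 4}
/// followed by a wait with {load = 1, exp = 0} fuses into a single wait with
/// {load = 1, ds = 4, exp = 0}; counters that are unset on both ops stay
/// unset.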
606struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
607 using Base::Base;
608
609 LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
610 PatternRewriter &rewriter) const override {
611 auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
612 if (!next)
613 return failure();
614
615 auto setters = {&MemoryCounterWaitOp::setLoad,
616 &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
617 &MemoryCounterWaitOp::setExp,
618 &MemoryCounterWaitOp::setTensor};
619 auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
620 op.getTensor()};
621 auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
622 next.getExp(), next.getTensor()};
623 rewriter.modifyOpInPlace(op, [&] {
624 for (auto [setter, lhs, rhs] :
625 llvm::zip_equal(setters, lhsVals, rhsVals)) {
626 if (lhs && rhs) {
627 (op.*setter)(std::min(*lhs, *rhs));
628 } else if (lhs) {
629 (op.*setter)(*lhs);
630 } else if (rhs) {
631 (op.*setter)(*rhs);
632 }
633 }
634 });
635 rewriter.eraseOp(next);
636 return success();
637 }
638};
639} // namespace
640
641void MemoryCounterWaitOp::getCanonicalizationPatterns(
642 RewritePatternSet &results, MLIRContext *context) {
643 results.add<FuseMemoryCounterWaitOp>(context);
644}
645
646//===----------------------------------------------------------------------===//
647// GatherToLDSOp
648//===----------------------------------------------------------------------===//
649
650LogicalResult GatherToLDSOp::verify() {
651 MemRefType srcType = cast<MemRefType>(getSrc().getType());
652 MemRefType dstType = cast<MemRefType>(getDst().getType());
653
654 if (!dstType.areTrailingDimsContiguous(1))
655 return emitOpError("destination type inner most dim must be contiguous");
656
657 auto elemType = srcType.getElementType();
658 // Check $src and $dst element types are the same.
659 if (elemType != dstType.getElementType())
660 return emitOpError("source and destination element types must match");
661
662 // copy type sizes should be 1, 2, 4, 12 or 16 bytes.
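  // For example (illustrative): a vector<4xf32> transfer type is 128 bits and
  // therefore allowed, as is vector<3xf32> (96 bits).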
663 auto transferType = getTransferType();
664 int transferSize;
665 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
666 transferSize = vectorTransfer.getNumElements() *
667 vectorTransfer.getElementTypeBitWidth();
668 } else {
669 transferSize = transferType.getIntOrFloatBitWidth();
670 }
671 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
672 return emitOpError(
673 "Transfering type size must be 8, 16, 32, 96 or 128 bits");
674
675 if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
676 !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
677 return emitOpError(
678 "source memory address space must be global or fat raw buffer");
679
680 if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
681 return emitOpError("destination memory address space must be Workgroup");
682
683 return success();
684}
685
686namespace {
687/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
688/// information or changes layout, the cast can be skipped.
689struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
690 using OpRewritePattern::OpRewritePattern;
691
692 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
693 PatternRewriter &rewriter) const override {
694 bool modified = false;
695 auto foldCast = [&](OpOperand &operand) {
696 if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
697 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
698 rewriter.modifyOpInPlace(gatherOp,
699 [&] { operand.assign(castOp.getSource()); });
700 modified = true;
701 }
702 }
703 };
704
705 foldCast(gatherOp.getSrcMutable());
706 foldCast(gatherOp.getDstMutable());
707
708 return success(modified);
709 }
710};
711} // namespace
712
713void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
714 MLIRContext *context) {
715 results.add<FoldGatherToLDSOfCast>(context);
716}
717
718//===----------------------------------------------------------------------===//
719// TransposeLoadOp
720//===----------------------------------------------------------------------===//
721
722LogicalResult TransposeLoadOp::verify() {
723 MemRefType srcType = cast<MemRefType>(getSrc().getType());
724
725 if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
726 return emitOpError("source memory address space must be Workgroup");
727
728 auto transferType = cast<VectorType>(getType());
729 size_t numElements = transferType.getNumElements();
730 size_t elementTypeSize =
731 transferType.getElementType().getIntOrFloatBitWidth();
732
733 // ElementSize -> NumElements
734 const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
735 {4, 16},
736 {6, 16},
737 {8, 8},
738 {16, 4},
739 };
740
741 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
742 if (validNumElems == kValidLoadSizeMap.end())
743 return emitOpError("Unsupported element type size for transpose load: ")
744 << elementTypeSize << " bits";
745
746 if (numElements != validNumElems->second)
747 return emitOpError(
748 "Transferring type size mismatch: expected num of elements: ")
749 << validNumElems->second;
750
751 return success();
752}
753
754//===----------------------------------------------------------------------===//
755// MakeDmaBaseOp
756//===----------------------------------------------------------------------===//
757
758LogicalResult MakeDmaBaseOp::verify() {
759
760 auto ldsType = cast<MemRefType>(getLds().getType());
761 auto globalType = cast<MemRefType>(getGlobal().getType());
762 if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
763 return emitOpError(
764 "lds memref must have workgroup address space attribute.");
765 if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
766 return emitOpError(
767 "global memref must have global address space attribute.");
768
769 Type elementType = ldsType.getElementType();
770 unsigned width = elementType.getIntOrFloatBitWidth();
771
772 if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
773 return emitOpError(
774 "element type must be 1, 2, 4, or 8 bytes long but type was ")
775 << width << " bits long.";
776
777 return success();
778}
779
780//===----------------------------------------------------------------------===//
781// MakeDmaDescriptorOp
782//===----------------------------------------------------------------------===//
783
784LogicalResult MakeDmaDescriptorOp::verify() {
785 ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides();
786
787 if (globalStaticStrides.empty())
788 return emitOpError("strides must not be empty.");
789 if (globalStaticStrides.back() != 1)
790 return emitOpError("strides for the innermost dimension must be 1.");
791
792 ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
793 size_t rank = globalStaticSizes.size();
794 if (rank > 5)
795 return emitOpError("tensor and tile must be at most of rank 5.");
796 if (rank != globalStaticStrides.size())
797 return emitOpError("strides and sizes must have same rank.");
798
799 ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes();
800 if (rank != sharedStaticSizes.size())
801 return emitOpError("tensor must have same rank as tile.");
802
803 unsigned elementTypeWidth = getElementTypeWidth();
804 if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidth))
805 return emitOpError(
806 "element type width must be 1, 2, 4 or 8 bytes, but was ")
807 << elementTypeWidth << " bits long";
808
809 if (Value atomicBarrierAddress = getAtomicBarrierAddress()) {
810 auto atomicBarrierAddressType =
811 cast<MemRefType>(atomicBarrierAddress.getType());
812 bool barrierInLDS =
813 hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
814 if (!barrierInLDS)
815 return emitOpError("atomic barrier address must be in LDS.");
816 }
817
818 if (getEarlyTimeout() && !getWorkgroupMask())
819 return emitOpError(
820 "early timeout does not apply when workgroup_mask is not set.");
821 return success();
822}
823
824OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
825 SmallVector<OpFoldResult> mixedGlobalSizes(getMixedGlobalSizes());
826 SmallVector<OpFoldResult> mixedGlobalStrides(getMixedGlobalStrides());
827 SmallVector<OpFoldResult> mixedSharedSizes(getMixedSharedSizes());
828
829 if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
830 /*onlyNonZero=*/true)) &&
831 failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
832 /*onlyNonZero=*/true)) &&
833 failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
834 /*onlyNonZero=*/true)))
835 return nullptr;
836
837 SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
838 dynamicSharedSizes;
839 SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
840 staticSharedSizes;
841
842 dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
843 staticGlobalSizes);
844 setGlobalStaticSizes(staticGlobalSizes);
845 getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
846
847 dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
848 staticGlobalStrides);
849 setGlobalStaticStrides(staticGlobalStrides);
850 getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
851
852 dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
853 staticSharedSizes);
854 setSharedStaticSizes(staticSharedSizes);
855 getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
856 return getResult();
857}
858
859//===----------------------------------------------------------------------===//
860// ScaledMFMAOp
861//===----------------------------------------------------------------------===//
862
863namespace {
864/// Check if the scales input is used by other scaled mfma ops while they exist.
865/// If they're unused, then pack the scales.
866struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
867 using OpRewritePattern::OpRewritePattern;
868
869 LogicalResult matchAndRewrite(ScaledMFMAOp op,
870 PatternRewriter &rewriter) const override {
871 Location loc = op.getLoc();
872 auto setOpsel = [&op](unsigned idx, int64_t val) {
873 switch (idx) {
874 case 3:
875 op.setScalesIdxA(val);
876 break;
877 case 4:
878 op.setScalesIdxB(val);
879 break;
880 default:
881 break;
882 }
883 };
884
885 // For every scale operand of this ScaledMFMAOp, if the scale is produced by
886 // the extraction of a single scale from some vector, then attempt to
887 // extract 4 values from that vector instead.
888 //
889 // Example: (f8 here means f8E8M0FNU)
890 // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
891 // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
892 // amdgpu.scaled_mfma(%scale[0] * ...
893 //
894 // rewrite to:
895 //
896 // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
897 // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
898 // amdgpu.scaled_mfma(%scale[0-3] * ...
899 //
900 // This creates duplicate shape_casts for every use but these will be
901 // removed in CSE.
902 for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
903 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
904 if (!insertOp) {
905 return rewriter.notifyMatchFailure(op,
906 "defining op not a vector.insert");
907 }
908 // If the inserted value is not a single scalar, then it has already been packed.
909 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
910 return rewriter.notifyMatchFailure(
911 op, "scaled mfma operand already packed");
912 }
913
914 auto extractOp =
915 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
916 if (!extractOp) {
917 return rewriter.notifyMatchFailure(op,
918 "defining op not a vector.extract");
919 }
920
921 Value scaleSrc = extractOp.getOperand(0);
922 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
923 if (!scaleSrcType) {
924 return rewriter.notifyMatchFailure(op, "not a vector type");
925 }
926
927 // We do not handle dynamic dims yet, assume that the input is padded to
928 // a static shape now.
929 if (!scaleSrcType.hasStaticShape()) {
930 return rewriter.notifyMatchFailure(op,
931 "dynamic dims not yet supported");
932 }
933
934 int64_t numElements = scaleSrcType.getNumElements();
935 if (numElements <= 4) {
936 return rewriter.notifyMatchFailure(
937 op, "no packing if # of scales less than four");
938 }
939
940 // Find a linearized idx using the size and offsets of the extract op.
941 auto extractedPos = llvm::to_vector_of<int64_t>(
942 llvm::reverse(extractOp.getStaticPosition()));
943 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
944 int64_t scaleSrcRank = scaleSrcType.getRank();
945 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
946 for (int64_t i = 1; i < scaleSrcRank; ++i) {
947 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
948 }
949 int64_t idx = linearize(extractedPos, extractSizes);
950
951 // All n scales (where n is the total number of scales) must now be
952 // extracted in chunks of 4 elements. This is done by dividing the
953 // original vector of scales into groups of 4 elements
954 // at offsets 0, 4, ..., n - 4. All extractions of a
955 // scale at a particular index are now replaced with an extraction
956 // of the entire group of 4 elements to which that index belongs.
957 //
958 // If the number of scales is not divisible by 4, extract
959 // the remaining trailing scales in a chunk of 4 elements starting at
960 // offset n - 4.
961 int64_t offset = idx - (idx % 4);
962 int64_t opsel = idx - offset;
963 int64_t size = 4l;
964 // Accommodate remaining elements in the case of non-4-divisible vectors.
965 if (numElements - offset < size) {
966 opsel = size - (numElements - idx);
967 offset = numElements - 4l;
968 }
969 Type scaleSrcElemType = scaleSrcType.getElementType();
970 auto newSrcType =
971 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
972 Value newScaleSrc =
973 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
974 auto extract = vector::ExtractStridedSliceOp::create(
975 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
976 ArrayRef{int64_t(1)});
977 rewriter.modifyOpInPlace(op, [&] {
978 op->setOperand(opIdx, extract);
979 setOpsel(opIdx, opsel);
980 });
981 }
982 return success();
983 }
984};
985} // namespace
986
987void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
988 MLIRContext *context) {
989 results.add<PackScales>(context);
990}
991
992#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
993
994#define GET_ATTRDEF_CLASSES
995#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
996
997#define GET_TYPEDEF_CLASSES
998#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
999
1000#define GET_OP_CLASSES
1001#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"