MLIR 23.0.0git
AMDGPUOps.cpp
Go to the documentation of this file.
1//===- AMDGPUOps.cpp - MLIR AMDGPU dialect operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect operations, their verifiers, and
10// their canonicalizations.
11//
12//===----------------------------------------------------------------------===//
13
15
24#include "mlir/IR/Builders.h"
26#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/Matchers.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/SmallVector.h"
33
34#include <algorithm>
35#include <cstdint>
36#include <limits>
37#include <optional>
38
39using namespace mlir;
40using namespace mlir::amdgpu;
41
42/// Verifies that the number of indices matches the rank of the indexed memref,
43/// emitting an op error mentioning `indexName` on mismatch.
44template <typename OpTy>
45static LogicalResult verifyIndexCount(OpTy op, StringRef indexName,
46 MemRefType memrefType,
47 int64_t numIndices) {
48 int64_t rank = memrefType.getRank();
49 if (rank != numIndices)
50 return op.emitOpError("expected ")
51 << rank << " " << indexName << " indices, got " << numIndices;
52 return success();
53}
54
55//===----------------------------------------------------------------------===//
56// 8-bit float ops
57//===----------------------------------------------------------------------===//
58LogicalResult PackedTrunc2xFp8Op::verify() {
59 if (getExisting() && getExisting().getType() != getResult().getType())
60 return emitOpError("existing values must have same type as result");
61 return success();
62}
63
64LogicalResult PackedStochRoundFp8Op::verify() {
65 if (getExisting() && getExisting().getType() != getResult().getType())
66 return emitOpError("existing values must have same type as result");
67 return success();
68}
69
70//===----------------------------------------------------------------------===//
71// mxfp float ops
72//===----------------------------------------------------------------------===//
73LogicalResult PackedScaledTruncOp::verify() {
74 if (getExisting() && getExisting().getType() != getResult().getType())
75 return emitOpError("existing values must have same type as result");
76 return success();
77}
78
79//===----------------------------------------------------------------------===//
80// FatRawBufferCastOp
81//===----------------------------------------------------------------------===//
82
83/// Convert the type `source` to one with the same sizes and strides - and
84/// offset, unless `stripOffset` is true, in which case the offset is reset to
85/// 0, if the offset should be reset but the layout of `source` isn't either the
86/// identity layout or a strided layout, this function fails.
87static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
88 bool resetOffset) {
89 MLIRContext *ctx = source.getContext();
90 MemRefType::Builder mb(source);
92 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
93 MemRefLayoutAttrInterface layout = source.getLayout();
94 if (resetOffset && !layout.isIdentity()) {
95 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
96 if (!stridedLayout)
97 return failure();
98 MemRefLayoutAttrInterface newLayout =
99 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
100 // Special case: if resetting the offset causes the strided layout to become
101 // the identity layout, then reset to the identity layout.
102 // TODO: this'll get a lot simpler when we have the contiguous layout.
103 SmallVector<int64_t> stridesIfIdentity;
104 if (source.hasStaticShape()) {
105 stridesIfIdentity = computeSuffixProduct(source.getShape());
106 } else if (source.getRank() <= 1) {
107 stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
108 }
109 if (stridesIfIdentity == stridedLayout.getStrides()) {
110 newLayout = AffineMapAttr::get(
111 AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
112 }
113 mb.setLayout(newLayout);
114 }
115 return (MemRefType)(mb);
116}
117
118LogicalResult FatRawBufferCastOp::inferReturnTypes(
119 MLIRContext *context, std::optional<Location> location, ValueRange operands,
120 DictionaryAttr attributes, PropertyRef properties, RegionRange regions,
121 SmallVectorImpl<Type> &inferredReturnTypes) {
122 Adaptor adaptor(operands, attributes, properties, regions);
123 auto sourceType =
124 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
125 if (!sourceType)
126 return failure();
127 FailureOr<MemRefType> resultType =
128 getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
129 if (failed(resultType))
130 return failure();
131 inferredReturnTypes = SmallVector<Type>{*resultType};
132 return success();
133}
134
135FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
136 int resultIndex,
137 int dim) {
138 assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
139 return memref::getMixedSize(builder, getLoc(), getSource(), dim);
140}
141
142LogicalResult FatRawBufferCastOp::verify() {
143 FailureOr<MemRefType> expectedResultType =
144 getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
145 if (failed(expectedResultType))
146 return emitOpError("source type ")
147 << getSource().getType() << " can't have its offset reset";
148 if (getResult().getType() != *expectedResultType)
149 return emitOpError("expected result type to be ")
150 << *expectedResultType << " but got " << getResult().getType();
151 return success();
152}
153
154//===----------------------------------------------------------------------===//
155// RawBuffer*Op
156//===----------------------------------------------------------------------===//
157template <typename T>
158static LogicalResult verifyRawBufferOp(T &op) {
159 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
160 bool isGlobal =
161 isGlobalMemorySpace(bufferType.getMemorySpace(), /*allowFlat=*/true);
162
163 if (!isGlobal)
164 return op.emitOpError(
165 "buffer ops must operate on a memref in global memory");
166 if (!bufferType.hasRank())
167 return op.emitOpError(
168 "cannot meaningfully buffer_store to an unranked memref");
169 return verifyIndexCount(op, "buffer", bufferType, op.getIndices().size());
170}
171
172LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
173
174LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
175
176LogicalResult RawBufferAtomicFaddOp::verify() {
177 return verifyRawBufferOp(*this);
178}
179
180LogicalResult RawBufferAtomicFmaxOp::verify() {
181 return verifyRawBufferOp(*this);
182}
183
184LogicalResult RawBufferAtomicSmaxOp::verify() {
185 return verifyRawBufferOp(*this);
186}
187
188LogicalResult RawBufferAtomicUminOp::verify() {
189 return verifyRawBufferOp(*this);
190}
191
192LogicalResult RawBufferAtomicCmpswapOp::verify() {
193 return verifyRawBufferOp(*this);
194}
195
196static std::optional<uint32_t> getConstantUint32(Value v) {
197 APInt cst;
198 if (!v.getType().isInteger(32))
199 return std::nullopt;
200 if (matchPattern(v, m_ConstantInt(&cst)))
201 return cst.getZExtValue();
202 return std::nullopt;
203}
204
205template <typename OpType>
206static bool staticallyOutOfBounds(OpType op) {
207 if (!op.getBoundsCheck())
208 return false;
209 MemRefType bufferType = op.getMemref().getType();
210 if (!bufferType.hasStaticShape())
211 return false;
212 int64_t offset;
213 SmallVector<int64_t> strides;
214 if (failed(bufferType.getStridesAndOffset(strides, offset)))
215 return false;
216 int64_t result = offset + op.getIndexOffset().value_or(0);
217 if (op.getSgprOffset()) {
218 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
219 if (!sgprOffset)
220 return false;
221 result += *sgprOffset;
222 }
223 if (strides.size() != op.getIndices().size())
224 return false;
225 int64_t indexVal = 0;
226 for (auto pair : llvm::zip(strides, op.getIndices())) {
227 int64_t stride = std::get<0>(pair);
228 Value idx = std::get<1>(pair);
229 std::optional<uint32_t> idxVal = getConstantUint32(idx);
230 if (!idxVal)
231 return false;
232 indexVal += stride * *idxVal;
233 }
234 result += indexVal;
235 if (result > std::numeric_limits<uint32_t>::max())
236 // Overflow means don't drop
237 return false;
238 return result >= bufferType.getNumElements();
239}
240
241namespace {
242template <typename OpType>
243struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
244 using OpRewritePattern<OpType>::OpRewritePattern;
245
246 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
247 if (!staticallyOutOfBounds(op))
248 return failure();
249 Type loadType = op.getResult().getType();
250 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
251 rw.getZeroAttr(loadType));
252 return success();
253 }
254};
255
256template <typename OpType>
257struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
258 using OpRewritePattern<OpType>::OpRewritePattern;
259
260 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
261 if (!staticallyOutOfBounds(op))
262 return failure();
263
264 rw.eraseOp(op);
265 return success();
266 }
267};
268} // end namespace
269
270void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
271 MLIRContext *context) {
272 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
273}
274
275void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
276 MLIRContext *context) {
277 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
278}
279
280void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
281 RewritePatternSet &results, MLIRContext *context) {
282 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicFaddOp>>(context);
283}
284
285void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
286 RewritePatternSet &results, MLIRContext *context) {
287 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicFmaxOp>>(context);
288}
289
290void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
291 RewritePatternSet &results, MLIRContext *context) {
292 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicSmaxOp>>(context);
293}
294
295void RawBufferAtomicUminOp::getCanonicalizationPatterns(
296 RewritePatternSet &results, MLIRContext *context) {
297 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicUminOp>>(context);
298}
299
300void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
301 RewritePatternSet &results, MLIRContext *context) {
302 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
303 context);
304}
305
306//===----------------------------------------------------------------------===//
307// ScaledExtPackedMatrixOp
308//===----------------------------------------------------------------------===//
309LogicalResult ScaledExtPackedMatrixOp::verify() {
310 int blockSize = getBlockSize();
311 assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
312
313 int firstScaleByte = getFirstScaleByte();
314 int firstScaleLane = getFirstScaleLane();
315 auto sourceType = cast<VectorType>(getSource().getType());
316 Type elementType = sourceType.getElementType();
317 auto floatType = cast<FloatType>(elementType);
318 unsigned bitWidth = floatType.getWidth();
319
320 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
321
322 const bool is_fp8 = bitWidth == 8;
323 const bool is_block_16 = blockSize == 16;
324
325 if (!is_fp8) {
326 if (is_block_16) {
327 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
328 return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
329 "or 1 for f4 and f6.");
330 }
331 } else {
332 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
333 return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
334 "or 2 for f4 and f6.");
335 }
336 }
337 } else {
338 if (is_block_16) {
339 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
340 ((firstScaleLane == 16) && (firstScaleByte == 2));
341 if (!is_valid) {
342 return emitOpError("blockSize of 16 can only have (firstScaleLane, "
343 "firstScaleByte) be (0, 0) or (16, 2) for f8.");
344 }
345 }
346 }
347
348 return success();
349}
350
351//===----------------------------------------------------------------------===//
352// WMMAOp
353//===----------------------------------------------------------------------===//
354
356 IntegerAttr &m, IntegerAttr &n,
357 IntegerAttr &k) {
358 SmallVector<int64_t, 3> dimensions;
359 if (parser.parseDimensionList(dimensions, false, false))
360 return failure();
361 if (dimensions.size() != 3)
362 return parser.emitError(parser.getCurrentLocation())
363 << "expected 3 dimensions in MNK dimension list";
364
365 m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
366 n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
367 k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
368 return success();
369}
370
371LogicalResult WMMAOp::verify() {
372 auto sourceAType = cast<VectorType>(getSourceA().getType());
373 auto sourceBType = cast<VectorType>(getSourceB().getType());
374 auto destType = cast<VectorType>(getDestC().getType());
375
376 Type sourceAElemType = sourceAType.getElementType();
377 Type sourceBElemType = sourceBType.getElementType();
378 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
379 return emitOpError("source vectors have different lengths: ")
380 << sourceAType << " vs. " << sourceBType;
381 }
382
383 bool isDestFloat = destType.getElementType().isFloat();
384 bool isSrcFloat = sourceAElemType.isFloat();
385
386 if (isDestFloat && !isSrcFloat)
387 return emitOpError("expected float sources with float destination");
388 if (!isDestFloat && isSrcFloat)
389 return emitOpError("expected int sources with int destination");
390
391 if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
392 return emitOpError(
393 "source element types must match (except for fp8/bf8) but have ")
394 << sourceAType << " and " << sourceBType;
395 }
396
397 if (isSrcFloat) {
398 if (getClamp())
399 return emitOpError("clamp flag is not supported for float types");
400 if (getUnsignedA() || getUnsignedB())
401 return emitOpError("unsigned flags are not supported for float types");
402 }
403 return success();
404}
405
406//===----------------------------------------------------------------------===//
407// ScaledWMMAOp
408//===----------------------------------------------------------------------===//
409
410LogicalResult ScaledWMMAOp::verify() {
411 // Helper functions for type classification.
412 auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
413 auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
414 auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
415 auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
416 auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
417 auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
418
419 auto sourceAType = cast<VectorType>(getSourceA().getType());
420 auto sourceBType = cast<VectorType>(getSourceB().getType());
421 auto destType = cast<VectorType>(getDestC().getType());
422
423 // Validate source element types are small floats (fp4/fp6/fp8).
424 Type aElemType = sourceAType.getElementType();
425 Type bElemType = sourceBType.getElementType();
426
427 // Validate vector lengths based on dimensions.
428 int64_t m = getM();
429 int64_t aLen = sourceAType.getNumElements();
430 int64_t bLen = sourceBType.getNumElements();
431 int64_t expectedOutLen = (m == 16) ? 8 : 16;
432
433 if (destType.getNumElements() != expectedOutLen)
434 return emitOpError("expected output vector of length ")
435 << expectedOutLen << " but got " << destType.getNumElements();
436
437 if (m == 16) {
438 // For 16×16×128: both A and B must be 64 elements.
439 if (aLen != 64)
440 return emitOpError(
441 "for 16x16x128, sourceA must have 64 elements but got ")
442 << aLen;
443 if (bLen != 64)
444 return emitOpError(
445 "for 16x16x128, sourceB must have 64 elements but got ")
446 << bLen;
447 } else { // m == 32
448 // For 32×16×128: only fp4 is supported, A is 128, B is 64.
449 if (!isF4(aElemType) && !isF4(bElemType))
450 return emitOpError("32x16x128 only supports fp4 element types");
451
452 if (aLen != 128)
453 return emitOpError(
454 "for 32x16x128, sourceA must have 128 elements but got ")
455 << aLen;
456 if (bLen != 64)
457 return emitOpError(
458 "for 32x16x128, sourceB must have 64 elements but got ")
459 << bLen;
460
461 // For 32x16x128, matrix A uses all 32 lanes so a_first_scale_lane must be
462 // 0.
463 if (getAFirstScaleLane() != 0)
464 return emitOpError("for 32x16x128, a_first_scale_lane must be 0");
465 }
466
467 // Validate scale types and their compatibility with matrix element types.
468 auto scaleAType = cast<VectorType>(getScaleA().getType());
469 auto scaleBType = cast<VectorType>(getScaleB().getType());
470 Type scaleAElemType = scaleAType.getElementType();
471 Type scaleBElemType = scaleBType.getElementType();
472
473 // Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN).
474 if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
475 return emitOpError(
476 "scale operands must have f8 element types (E8M0FNU or E4M3FN)");
477
478 // Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid.
479 if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
480 return success();
481
482 // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M3|E4M3).
483 if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
484 isF4(bElemType) && isE4M3(scaleBElemType))
485 return success();
486
487 // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M3|E4M3), Scale B (E8M0).
488 if (isF4(aElemType) && isE4M3(scaleAElemType) &&
489 (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
490 return success();
491
492 // Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3).
493 if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
494 isE4M3(scaleBElemType))
495 return success();
496
497 // No valid combination matched.
498 return emitOpError("invalid combination of matrix and scale types: ")
499 << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
500 << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
501}
502
503//===----------------------------------------------------------------------===//
504// MFMAOp
505//===----------------------------------------------------------------------===//
506LogicalResult MFMAOp::verify() {
507 constexpr uint32_t waveSize = 64;
509
510 Type sourceType = getSourceA().getType();
511 Type destType = getDestC().getType();
512
513 Type sourceElem = sourceType, destElem = destType;
514 uint32_t sourceLen = 1, destLen = 1;
515 if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
516 sourceLen = sourceVector.getNumElements();
517 sourceElem = sourceVector.getElementType();
518 }
519 if (auto destVector = dyn_cast<VectorType>(destType)) {
520 destLen = destVector.getNumElements();
521 destElem = destVector.getElementType();
522 }
523
524 Type sourceBType = getSourceB().getType();
525 if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
526 int64_t sourceBLen = 1;
527 Type sourceBElem = sourceBType;
528 if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
529 sourceBLen = sourceBVector.getNumElements();
530 sourceBElem = sourceBVector.getElementType();
531 }
532 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
533 !sourceBElem.isFloat(4))
534 return emitOpError("expected both source operands to have small-float "
535 "elements if one does");
536 if (sourceLen != sourceBLen)
537 return emitOpError(
538 "expected both small-float source vectors to have the same length");
539 } else {
540 if (sourceType != sourceBType)
541 return emitOpError("expected both non-small-float source operand types "
542 "to match exactly");
543 }
544 // Normalize the wider integer types the compiler expects to i8.
545 if (sourceElem.isInteger(32)) {
546 sourceLen *= 4;
547 sourceElem = b.getI8Type();
548 }
549 if (sourceElem.isInteger(64)) {
550 sourceLen *= 8;
551 sourceElem = b.getI8Type();
552 }
553
554 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
555 if (sourceLen != numSourceElems)
556 return emitOpError("expected " + Twine(numSourceElems) +
557 " source values for this operation but got " +
558 Twine(sourceLen));
559
560 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
561 if (destLen != numDestElems)
562 return emitOpError("expected " + Twine(numDestElems) +
563 " result values for this operation but got " +
564 Twine(destLen));
565
566 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
567 return emitOpError(
568 "double-precision ops do not support permuting lanes of B");
569 if (destElem.isF64() && getCbsz() != 0)
570 return emitOpError(
571 "double-precision ops do not support permuting lanes of A");
572 if (getAbid() >= (1u << getCbsz()))
573 return emitOpError(
574 "block ID for permuting A (abid) must be below 2 ** cbsz");
575
576 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
577 return emitOpError(
578 "negation flags only available for double-precision operations");
579
580 return success();
581}
582
583//===----------------------------------------------------------------------===//
584// SparseMFMAOp
585//===----------------------------------------------------------------------===//
586
587LogicalResult SparseMFMAOp::verify() {
588 constexpr uint32_t waveSize = 64;
589
590 auto sparseType = cast<VectorType>(getSourceA().getType());
591 auto denseType = cast<VectorType>(getSourceB().getType());
592 auto destType = cast<VectorType>(getDestC().getType());
593
594 Type sparseElem = sparseType.getElementType();
595 Type denseElem = denseType.getElementType();
596 int64_t sparseLen = sparseType.getNumElements();
597 int64_t denseLen = denseType.getNumElements();
598 int64_t destLen = destType.getNumElements();
599
600 if (denseLen != 2 * sparseLen)
601 return emitOpError("expected dense source operand to have exactly double "
602 "the number of elements of the sparse source operand");
603
604 // Check that source element types are compatible.
605 // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
606 // For other types, element types must match exactly.
607 bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
608 if (!bothFloat8 && sparseElem != denseElem)
609 return emitOpError(
610 "expected source operands to have the same element type");
611
612 // Classify the sparse MFMA variant. The three flavors differ in CBSZ/ABID
613 // handling and in the sparse-index layout:
614 // - gfx942 16-bit: max ABID = 3, sparse idx = vector<4xi8>
615 // - gfx950 16-bit / gfx942 8-bit: max ABID = 1, sparse idx = vector<2xi16>
616 // - gfx950 8-bit: CBSZ/ABID ignored by hw, sparse idx = i32
617 uint32_t m = getM(), k = getK();
618 bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
619 bool is16BitGfx942 =
620 !is8BitSource && ((m == 16 && k == 32) || (m == 32 && k == 16));
621 bool is8BitGfx950 =
622 is8BitSource && ((m == 16 && k == 128) || (m == 32 && k == 64));
623
624 // CBSZ/ABID range check. On gfx950 8-bit the hardware always uses the first
625 // set and ignores these fields, so require zeros in IR. Otherwise ABID is
626 // only meaningful when CBSZ == 0 (when CBSZ != 0 the first set is always
627 // used and ABID is irrelevant, so the verifier accepts any value).
628 if (is8BitGfx950) {
629 if (getCbsz() != 0)
630 return emitOpError(
631 "CBSZ must be 0 for this variant (field is ignored by hardware)");
632 if (getAbid() != 0)
633 return emitOpError(
634 "ABID must be 0 for this variant (field is ignored by hardware)");
635 } else if (getCbsz() == 0) {
636 unsigned maxAbid = is16BitGfx942 ? 3u : 1u;
637 if (getAbid() > maxAbid)
638 return emitOpError("ABID must be in [0, ")
639 << maxAbid << "] for this variant";
640 }
641
642 Type sparseIdxType = getSparseIdx().getType();
643 if (is8BitGfx950) {
644 if (!sparseIdxType.isInteger(32))
645 return emitOpError("expected i32 sparse indices for this variant "
646 "(no internal set structure), but got ")
647 << sparseIdxType;
648 } else {
649 unsigned expectedIdxElems = is16BitGfx942 ? 4 : 2;
650 unsigned expectedIdxBits = is16BitGfx942 ? 8 : 16;
651 auto vecType = dyn_cast<VectorType>(sparseIdxType);
652 if (!vecType || vecType.getNumElements() != expectedIdxElems ||
653 !vecType.getElementType().isInteger(expectedIdxBits))
654 return emitOpError("expected vector<")
655 << expectedIdxElems << "xi" << expectedIdxBits
656 << "> sparse indices for this variant, but got " << sparseIdxType;
657 }
658
659 int64_t expectedSourceElems = (getM() * getK()) / waveSize;
660 if (denseLen != expectedSourceElems)
661 return emitOpError("expected " + Twine(expectedSourceElems) +
662 " source values for this operation but got " +
663 Twine(denseLen));
664
665 int64_t expectedDestElems = (getM() * getN()) / waveSize;
666 if (destLen != expectedDestElems)
667 return emitOpError("expected " + Twine(expectedDestElems) +
668 " result values for this operation but got " +
669 Twine(destLen));
670
671 return success();
672}
673
674//===----------------------------------------------------------------------===//
675// SparseWMMAOp
676//===----------------------------------------------------------------------===//
677
678LogicalResult SparseWMMAOp::verify() {
679 auto sparseType = cast<VectorType>(getSourceA().getType());
680 auto denseType = cast<VectorType>(getSourceB().getType());
681 auto destType = cast<VectorType>(getDestC().getType());
682
683 Type sparseElem = sparseType.getElementType();
684 Type denseElem = denseType.getElementType();
685 Type destElem = destType.getElementType();
686 int64_t sparseLen = sparseType.getNumElements();
687 int64_t denseLen = denseType.getNumElements();
688 int64_t destLen = destType.getNumElements();
689
690 uint32_t m = getM(), n = getN(), k = getK();
691 if ((m != 16) || (n != 16))
692 return emitOpError("expected MxN to be exactly 16x16");
693
694 const bool isWavesize64 = getWave64();
695 const bool isInt4Input = sparseElem.isInteger(4) && denseElem.isInteger(4);
696 const bool isEqualLengthAllowed = isWavesize64 && isInt4Input && k == 32;
697
698 if ((denseLen != 2 * sparseLen) && !isEqualLengthAllowed)
699 return emitOpError("expected dense source operand to have exactly double "
700 "the number of elements of the sparse source operand");
701
702 if (isEqualLengthAllowed && (denseLen != sparseLen))
703 return emitOpError("expected dense source operand to have exactly the "
704 "same the number of elements");
705
706 if (destElem.isInteger()) {
707 if (!(sparseElem.isInteger() && denseElem.isInteger())) {
708 return emitOpError("source operand and destination operands must all be "
709 "either integer or float types");
710 }
711 }
712
713 if (destElem.isFloat()) {
714 if (!(sparseElem.isFloat() && denseElem.isFloat())) {
715 return emitOpError("source operand and destination operands must all be "
716 "either integer or float types");
717 }
718 }
719
720 // Check that source element types are compatible.
721 // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
722 // For other types, element types must match exactly.
723 bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
724 if (!bothFloat8 && sparseElem != denseElem)
725 return emitOpError(
726 "expected source operands to have the same element type");
727
728 const int64_t waveSize = isWavesize64 ? 64 : 32;
729
730 int64_t expectedSourceElems = (getM() * getK()) / waveSize;
731 if (denseLen != expectedSourceElems)
732 return emitOpError("expected " + Twine(expectedSourceElems) +
733 " source values for this operation but got " +
734 Twine(denseLen));
735
736 int64_t expectedDestElems = (getM() * getN()) / waveSize;
737 if (destLen != expectedDestElems)
738 return emitOpError("expected " + Twine(expectedDestElems) +
739 " result values for this operation but got " +
740 Twine(destLen));
741
742 return success();
743}
744
745//===----------------------------------------------------------------------===//
746// DotOp
747//===----------------------------------------------------------------------===//
748LogicalResult DotOp::verify() {
749 Type aElem = cast<VectorType>(getSourceA().getType()).getElementType();
750 Type bElem = cast<VectorType>(getSourceB().getType()).getElementType();
751 Type dest = getDestC().getType();
752
753 bool aIsFloat8 = aElem.isFloat(8);
754 bool bIsFloat8 = bElem.isFloat(8);
755 bool aIsInteger = isa<IntegerType>(aElem);
756
757 bool bothFloat8 = aIsFloat8 && bIsFloat8;
758 if (!bothFloat8 && aElem != bElem)
759 return emitOpError(
760 "expected source operands to have the same element type");
761
762 if (aElem.isF16()) {
763 if (!dest.isF32() && !dest.isF16())
764 return emitOpError("expected f32 or f16 accumulator for f16 sources");
765 } else if (aElem.isBF16()) {
766 if (!dest.isF32() && !dest.isBF16())
767 return emitOpError("expected f32 or bf16 accumulator for bf16 sources");
768 } else if (aIsInteger) {
769 if (!dest.isInteger(32))
770 return emitOpError("expected i32 accumulator for integer sources");
771 } else if (aIsFloat8) {
772 if (!dest.isF32())
773 return emitOpError("expected f32 accumulator for fp8 sources");
774 }
775
776 if ((getUnsignedA() || getUnsignedB()) && !aIsInteger)
777 return emitOpError(
778 "unsignedA/unsignedB are only valid for integer source types");
779
780 if (aElem.isInteger(16) && getUnsignedA() != getUnsignedB())
781 return emitOpError(
782 "mixed-sign dot is not supported for 16-bit integer sources");
783
784 if (getClamp()) {
785 bool noClamp = (aElem.isF16() && dest.isF16()) ||
786 (aElem.isBF16() && dest.isBF16()) || aIsFloat8;
787 if (noClamp)
788 return emitOpError(
789 "clamp is not supported for this (source, accumulator) combination");
790 }
791
792 return success();
793}
794
795//===----------------------------------------------------------------------===//
796// DPPOp
797//===----------------------------------------------------------------------===//
798LogicalResult DPPOp::verify() {
799 DPPPerm kind = getKind();
800 Attribute permArgument = getPermArgument().value_or(Attribute{});
801
802 switch (kind) {
803
804 case DPPPerm::quad_perm: {
805 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
806 if (!quadPermAttr || quadPermAttr.size() != 4) {
807 return emitOpError("quad_perm attribute must have exactly 4 elements");
808 }
809 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
810 int32_t num = elem.getInt();
811 if (num < 0 || num > 3) {
812 return emitOpError(
813 "Each element of quad_perm must be in the range [0, 3]");
814 }
815 }
816 } break;
817
818 case DPPPerm::row_shl:
819 case DPPPerm::row_shr:
820 case DPPPerm::row_ror: {
821 if (!permArgument) {
822 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
823 "' value not specified");
824 }
825 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
826 uint32_t attrValue = intAttr.getInt();
827 if (attrValue < 1 || attrValue > 15) {
828 return emitOpError("Attribute value must be between 1 and 15");
829 }
830 }
831 } break;
832
833 case DPPPerm::wave_shl:
834 case DPPPerm::wave_shr:
835 case DPPPerm::wave_rol:
836 case DPPPerm::wave_ror:
837 case DPPPerm::row_mirror:
838 case DPPPerm::row_half_mirror:
839 case DPPPerm::row_bcast_15:
840 case DPPPerm::row_bcast_31: {
841 if (permArgument && !isa<UnitAttr>(permArgument)) {
842 return emitOpError("Expected unit attribute for permArgument, but found "
843 "non-trivial argument");
844 }
845 break;
846 }
847 }
848 return success();
849}
850
851//===----------------------------------------------------------------------===//
852// PermlaneSwapOp
853//===----------------------------------------------------------------------===//
854LogicalResult PermlaneSwapOp::verify() {
855 unsigned rowLength = getRowLength();
856
857 if (rowLength != 16 && rowLength != 32)
858 return emitOpError("row_length attribute must either be 16 or 32.");
859
860 return success();
861}
862
863/// Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
864static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op,
865 PatternRewriter &rewriter) {
866 if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
867 rewriter.eraseOp(op);
868 return success();
869 }
870 return failure();
871}
872
873void LDSBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
874 MLIRContext *context) {
876}
877
878//===----------------------------------------------------------------------===//
879// MemoryCounterWaitOp
880//===----------------------------------------------------------------------===//
881
882namespace {
883/// Fuse adjacent memory counter wait ops, taking the minimum value of the
884/// counters.
885struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
886 using Base::Base;
887
888 LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
889 PatternRewriter &rewriter) const override {
890 auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
891 if (!next)
892 return failure();
893
894 auto setters = {&MemoryCounterWaitOp::setLoad,
895 &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
896 &MemoryCounterWaitOp::setExp,
897 &MemoryCounterWaitOp::setTensor};
898 auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
899 op.getTensor()};
900 auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
901 next.getExp(), next.getTensor()};
902 rewriter.modifyOpInPlace(op, [&] {
903 for (auto [setter, lhs, rhs] :
904 llvm::zip_equal(setters, lhsVals, rhsVals)) {
905 if (lhs && rhs) {
906 (op.*setter)(std::min(*lhs, *rhs));
907 } else if (lhs) {
908 (op.*setter)(*lhs);
909 } else if (rhs) {
910 (op.*setter)(*rhs);
911 }
912 }
913 });
914 rewriter.eraseOp(next);
915 return success();
916 }
917};
918} // namespace
919
920void MemoryCounterWaitOp::getCanonicalizationPatterns(
921 RewritePatternSet &results, MLIRContext *context) {
922 results.add<FuseMemoryCounterWaitOp>(context);
923}
924
925//===----------------------------------------------------------------------===//
926// GatherToLDSOp
927//===----------------------------------------------------------------------===//
928
929LogicalResult GatherToLDSOp::verify() {
930 MemRefType srcType = cast<MemRefType>(getSrc().getType());
931 MemRefType dstType = cast<MemRefType>(getDst().getType());
932
933 if (failed(
934 verifyIndexCount(*this, "source", srcType, getSrcIndices().size())) ||
935 failed(verifyIndexCount(*this, "destination", dstType,
936 getDstIndices().size())))
937 return failure();
938
939 if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
940 return emitOpError("destination type inner most dim must be contiguous");
941
942 auto elemType = srcType.getElementType();
943 // Check $src and $dst element types are the same.
944 if (elemType != dstType.getElementType())
945 return emitOpError("source and destination element types must match");
946
947 // copy type sizes should be 1, 2, 4, 12 or 16 bytes.
948 auto transferType = getTransferType();
949 int transferSize;
950 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
951 transferSize = vectorTransfer.getNumElements() *
952 vectorTransfer.getElementTypeBitWidth();
953 } else {
954 transferSize = transferType.getIntOrFloatBitWidth();
955 }
956 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
957 return emitOpError(
958 "Transfering type size must be 8, 16, 32, 96 or 128 bits");
959
960 if (!isGlobalMemorySpace(srcType.getMemorySpace(), /*allowFlat=*/true) &&
961 !isFatRawBufferMemorySpace(srcType.getMemorySpace()))
962 return emitOpError(
963 "source memory address space must be global or fat raw buffer");
964
965 if (!isWorkgroupMemorySpace(dstType.getMemorySpace()))
966 return emitOpError("destination memory address space must be Workgroup");
967
968 return success();
969}
970
971namespace {
972/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
973/// information or changes layout, the cast can be skipped.
974struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
976
977 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
978 PatternRewriter &rewriter) const override {
979 bool modified = false;
980 auto foldCast = [&](OpOperand &operand) {
981 if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
982 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
983 rewriter.modifyOpInPlace(gatherOp,
984 [&] { operand.assign(castOp.getSource()); });
985 modified = true;
986 }
987 }
988 };
989
990 foldCast(gatherOp.getSrcMutable());
991 foldCast(gatherOp.getDstMutable());
992
993 return success(modified);
994 }
995};
996} // namespace
997
998void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
999 MLIRContext *context) {
1000 results.add<FoldGatherToLDSOfCast>(context);
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// GlobalLoadAsyncToLDSOp
1005//===----------------------------------------------------------------------===//
1006
1007LogicalResult GlobalLoadAsyncToLDSOp::verify() {
1008 MemRefType srcType = cast<MemRefType>(getSrc().getType());
1009 MemRefType dstType = cast<MemRefType>(getDst().getType());
1010
1011 if (failed(
1012 verifyIndexCount(*this, "source", srcType, getSrcIndices().size())) ||
1013 failed(verifyIndexCount(*this, "destination", dstType,
1014 getDstIndices().size())))
1015 return failure();
1016
1017 if (srcType.getElementType() != dstType.getElementType())
1018 return emitOpError("source and destination element types must match");
1019
1020 Type transferType = getTransferType();
1021 int transferSize;
1022 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
1023 transferSize = vectorTransfer.getNumElements() *
1024 vectorTransfer.getElementTypeBitWidth();
1025 } else {
1026 transferSize = transferType.getIntOrFloatBitWidth();
1027 }
1028 if (!llvm::is_contained({8, 32, 64, 128}, transferSize))
1029 return emitOpError("transfer type size must be 8, 32, 64, or 128 bits");
1030
1031 if (!isGlobalMemorySpace(srcType.getMemorySpace(), /*allowFlat=*/false))
1032 return emitOpError("source memory address space must be global");
1033
1034 if (!isWorkgroupMemorySpace(dstType.getMemorySpace()))
1035 return emitOpError("destination memory address space must be Workgroup");
1036
1037 return success();
1038}
1039
1040static LogicalResult
1042 PatternRewriter &rewriter) {
1043 Value mask = op.getMask();
1044 if (!mask)
1045 return failure();
1046
1047 APInt maskValue;
1048 if (!matchPattern(mask, m_ConstantInt(&maskValue)))
1049 return failure();
1050
1051 if (maskValue.isZero()) {
1052 rewriter.eraseOp(op);
1053 return success();
1054 }
1055
1056 rewriter.modifyOpInPlace(op, [&]() { op.getMaskMutable().clear(); });
1057 return success();
1058}
1059
1060void GlobalLoadAsyncToLDSOp::getCanonicalizationPatterns(
1061 RewritePatternSet &results, MLIRContext *context) {
1063}
1064
1065//===----------------------------------------------------------------------===//
1066// TransposeLoadOp
1067//===----------------------------------------------------------------------===//
1068
1069LogicalResult TransposeLoadOp::verify() {
1070 MemRefType srcType = cast<MemRefType>(getSrc().getType());
1071
1072 if (failed(
1073 verifyIndexCount(*this, "source", srcType, getSrcIndices().size())))
1074 return failure();
1075
1076 if (!isWorkgroupMemorySpace(srcType.getMemorySpace()))
1077 return emitOpError("source memory address space must be Workgroup");
1078
1079 auto transferType = cast<VectorType>(getType());
1080 size_t numElements = transferType.getNumElements();
1081 size_t elementTypeSize =
1082 transferType.getElementType().getIntOrFloatBitWidth();
1083
1084 auto emitNumElementsError = [&](StringRef expected) {
1085 return emitOpError(
1086 "Transferring type size mismatch: expected num of elements: ")
1087 << expected;
1088 };
1089
1090 switch (elementTypeSize) {
1091 case 4:
1092 case 6:
1093 if (numElements != 16)
1094 return emitNumElementsError("16");
1095 break;
1096 case 8:
1097 if (numElements != 8)
1098 return emitNumElementsError("8");
1099 break;
1100 case 16:
1101 if (numElements != 4 && numElements != 8)
1102 return emitNumElementsError("4 or 8");
1103 break;
1104 default:
1105 return emitOpError("Unsupported element type size for transpose load: ")
1106 << elementTypeSize << " bits";
1107 }
1108
1109 return success();
1110}
1111
1112//===----------------------------------------------------------------------===//
1113// GlobalTransposeLoadOp
1114//===----------------------------------------------------------------------===//
1115
1116LogicalResult GlobalTransposeLoadOp::verify() {
1117 MemRefType srcType = cast<MemRefType>(getSrc().getType());
1118
1119 if (failed(
1120 verifyIndexCount(*this, "source", srcType, getSrcIndices().size())))
1121 return failure();
1122
1123 if (!isGlobalMemorySpace(srcType.getMemorySpace(), /*allowFlat=*/false))
1124 return emitOpError("source memory address space must be Global");
1125
1126 auto resultType = cast<VectorType>(getType());
1127 size_t numElements = resultType.getNumElements();
1128 size_t elementTypeSize = resultType.getElementType().getIntOrFloatBitWidth();
1129
1130 // ElementSize -> NumElements. Chipset gating (gfx1200 vs gfx1250) is
1131 // enforced in the lowering.
1132 static const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
1133 {4, 16}, // global_load_tr4_b64 (gfx1250+)
1134 {6, 16}, // global_load_tr6_b96 (gfx1250+)
1135 {8, 8}, // global_load_tr_b64 (gfx1200+)
1136 {16, 8}, // global_load_tr_b128 (gfx1200+)
1137 };
1138
1139 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
1140 if (validNumElems == kValidLoadSizeMap.end())
1141 return emitOpError(
1142 "unsupported element type size for global transpose load: ")
1143 << elementTypeSize << " bits";
1144
1145 if (numElements != validNumElems->second)
1146 return emitOpError(
1147 "transferring type size mismatch: expected num of elements: ")
1148 << validNumElems->second;
1149
1150 return success();
1151}
1152
1153//===----------------------------------------------------------------------===//
1154// MakeDmaBaseOp
1155//===----------------------------------------------------------------------===//
1156
1157template <typename BaseOp>
1158static LogicalResult verifyBase(BaseOp op) {
1159 auto ldsType = cast<MemRefType>(op.getLds().getType());
1160 auto globalType = cast<MemRefType>(op.getGlobal().getType());
1161 if (failed(verifyIndexCount(op, "global", globalType,
1162 op.getGlobalIndices().size())) ||
1163 failed(verifyIndexCount(op, "lds", ldsType, op.getLdsIndices().size())))
1164 return failure();
1165
1166 if (!isWorkgroupMemorySpace(ldsType.getMemorySpace()))
1167 return op.emitOpError(
1168 "lds memref must have workgroup address space attribute.");
1169 if (!isGlobalMemorySpace(globalType.getMemorySpace(), /*allowFlat=*/false))
1170 return op.emitOpError(
1171 "global memref must have global address space attribute.");
1172
1173 Type elementType = ldsType.getElementType();
1174 unsigned width = elementType.getIntOrFloatBitWidth();
1175
1176 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
1177 return op.emitOpError(
1178 "element type must be 1, 2, 4, or 8 bytes long but type was ")
1179 << width << " bits long.";
1180 return success();
1181}
1182
1183LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
1184
1185//===----------------------------------------------------------------------===//
1186// MakeGatherDmaBaseOp
1187//===----------------------------------------------------------------------===//
1188
1189LogicalResult
1190TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
1191 Type elementType, Type indexType) {
1192 unsigned width = elementType.getIntOrFloatBitWidth();
1193 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
1194 return emitError()
1195 << "element type must be 1, 2, 4, or 8 bytes wide but type "
1196 << elementType << " is " << width / 8 << " bytes wide.";
1197 MLIRContext *ctx = elementType.getContext();
1198 Type i16 = IntegerType::get(ctx, 32);
1199 Type i32 = IntegerType::get(ctx, 16);
1200 if (!llvm::is_contained({i16, i32}, indexType))
1201 return emitError() << "index type must be i16 or i32 but index type is "
1202 << indexType << ".";
1203 return success();
1204}
1205
1206LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
1207
1208//===----------------------------------------------------------------------===//
1209// MakeDmaDescriptorOp
1210//===----------------------------------------------------------------------===//
1211
1212template <typename DescriptorOp>
1213static LogicalResult verifyDescriptorOp(DescriptorOp op) {
1214 ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();
1215
1216 if (globalStaticStrides.empty())
1217 return op.emitOpError("strides must not be empty.");
1218 if (globalStaticStrides.back() != 1)
1219 return op.emitOpError("strides for the innermost dimension must be 1.");
1220
1221 ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
1222 size_t rank = globalStaticSizes.size();
1223 if (rank > 5)
1224 return op.emitOpError("tensor and tile must be at most of rank 5.");
1225 if (rank != globalStaticStrides.size())
1226 return op.emitOpError("strides and sizes must have same rank.");
1227
1228 ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
1229 if (rank != sharedStaticSizes.size())
1230 return op.emitOpError("tensor must have same rank as tile.");
1231
1232 unsigned elementTypeWidth = op.getElementTypeWidth();
1233 if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
1234 return op.emitOpError(
1235 "element type width must be 1, 2, 4 or 8 bytes, but was ")
1236 << elementTypeWidth << " bits long";
1237
1238 if (!op.getAtomicBarrierAddress() && !op.getAtomicBarrierIndices().empty())
1239 return op.emitOpError(
1240 "atomic barrier indices require an atomic barrier address");
1241
1242 if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
1243 auto atomicBarrierAddressType =
1244 cast<MemRefType>(atomicBarrierAddress.getType());
1245 if (failed(verifyIndexCount(op, "atomic barrier", atomicBarrierAddressType,
1246 op.getAtomicBarrierIndices().size())))
1247 return failure();
1248
1249 bool barrierInLDS =
1250 isWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
1251 if (!barrierInLDS)
1252 return op.emitOpError("atomic barrier address must be in LDS.");
1253 }
1254
1255 if (op.getEarlyTimeout() && !op.getWorkgroupMask())
1256 return op.emitOpError(
1257 "early timeout does not apply when workgroup_mask is not set.");
1258 return success();
1259}
1260
1261template <typename DescriptorOp, typename FoldAdaptor>
1262static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
1263 SmallVector<OpFoldResult> mixedGlobalSizes(op.getMixedGlobalSizes());
1264 SmallVector<OpFoldResult> mixedGlobalStrides(op.getMixedGlobalStrides());
1265 SmallVector<OpFoldResult> mixedSharedSizes(op.getMixedSharedSizes());
1266
1267 if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
1268 /*onlyNonZero=*/true)) &&
1269 failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
1270 /*onlyNonZero=*/true)) &&
1271 failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
1272 /*onlyNonZero=*/true)))
1273 return nullptr;
1274
1275 SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
1276 dynamicSharedSizes;
1277 SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
1278 staticSharedSizes;
1279
1280 dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
1281 staticGlobalSizes);
1282 op.setGlobalStaticSizes(staticGlobalSizes);
1283 op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
1284
1285 dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
1286 staticGlobalStrides);
1287 op.setGlobalStaticStrides(staticGlobalStrides);
1288 op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
1289
1290 dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
1291 staticSharedSizes);
1292 op.setSharedStaticSizes(staticSharedSizes);
1293 op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
1294 return op.getResult();
1295}
1296
1297LogicalResult MakeDmaDescriptorOp::verify() {
1298 return verifyDescriptorOp(*this);
1299}
1300
1301OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1302 return foldDescriptorOp(*this, adaptor);
1303}
1304
1305//===----------------------------------------------------------------------===//
1306// MakeGatherDmaDescriptorOp
1307//===----------------------------------------------------------------------===//
1308
1309LogicalResult MakeGatherDmaDescriptorOp::verify() {
1310 ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
1311 size_t rank = globalStaticSizes.size();
1312 if (rank > 2)
1313 return emitOpError(
1314 "tensor and tile must be at most of rank two in gather mode.");
1316 Type elementType = cast<VectorType>(indices.getType()).getElementType();
1317 if (elementType != getBase().getType().getIndexType())
1318 return emitOpError("indices' element type must match base's element type.");
1319
1320 return verifyDescriptorOp(*this);
1321}
1322
1323OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1324 return foldDescriptorOp(*this, adaptor);
1325}
1326
1327//===----------------------------------------------------------------------===//
1328// ScaledMFMAOp
1329//===----------------------------------------------------------------------===//
1330
1331namespace {
1332/// Check if the scales input is used in other scaled mfma's while they exist.
1333/// If theyre unused then pack the scales.
1334struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
1336
1337 LogicalResult matchAndRewrite(ScaledMFMAOp op,
1338 PatternRewriter &rewriter) const override {
1339 Location loc = op.getLoc();
1340 auto setOpsel = [&op](unsigned idx, int64_t val) {
1341 switch (idx) {
1342 case 3:
1343 op.setScalesIdxA(val);
1344 break;
1345 case 4:
1346 op.setScalesIdxB(val);
1347 break;
1348 default:
1349 break;
1350 }
1351 };
1352
1353 // For every scale operand of this ScaledMFMAOp, if the scale is produced by
1354 // the extraction of a single scale from some vector, then attempt to
1355 // extract 4 values from that vector instead.
1356 //
1357 // Example: (f8 here means f8E8M0FNU)
1358 // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
1359 // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
1360 // amdgpu.scaled_mfma(%scale[0] * ...
1361 //
1362 // rewrite to:
1363 //
1364 // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
1365 // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
1366 // amdgpu.scaled_mfma(%scale[0-3] * ...
1367 //
1368 // This creates duplicate shape_casts for every use but these will be
1369 // removed in CSE.
1370 for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
1371 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
1372 if (!insertOp) {
1373 return rewriter.notifyMatchFailure(op,
1374 "defining op not a vector.insert");
1375 }
1376 // If the extracted value is not a single scalar, then it has been packed.
1377 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
1378 return rewriter.notifyMatchFailure(
1379 op, "scaled mfma operand already packed");
1380 }
1381
1382 auto extractOp =
1383 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
1384 if (!extractOp) {
1385 return rewriter.notifyMatchFailure(op,
1386 "defining op not a vector.extract");
1387 }
1388
1389 Value scaleSrc = extractOp.getOperand(0);
1390 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
1391 if (!scaleSrcType) {
1392 return rewriter.notifyMatchFailure(op, "not a vector type");
1393 }
1394
1395 // We do not handle dynamic dims yet, assume that the input is padded to
1396 // a static shape now.
1397 if (!scaleSrcType.hasStaticShape()) {
1398 return rewriter.notifyMatchFailure(op,
1399 "dynamic dims not yet supported");
1400 }
1401
1402 int64_t numElements = scaleSrcType.getNumElements();
1403 if (numElements < 4) {
1404 return rewriter.notifyMatchFailure(
1405 op, "do not pack if # of scales less than four");
1406 }
1407
1408 // Find a linearized idx using the size and offsets of the extract op.
1409 auto extractedPos = llvm::to_vector_of<int64_t>(
1410 llvm::reverse(extractOp.getStaticPosition()));
1411 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
1412 int64_t scaleSrcRank = scaleSrcType.getRank();
1413 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
1414 for (int64_t i = 1; i < scaleSrcRank; ++i) {
1415 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
1416 }
1417 int64_t idx = linearize(extractedPos, extractSizes);
1418
1419 // All n scales (where n is the total number of scales) must now be
1420 // extracted in chunks of 4 elements. This is done by dividing the
1421 // original vector of scales into groups of 4 elements
1422 // at offsets 0, 4, ..., m (where m = n/4). All extractions of a
1423 // scale at a particular index are now replaced with an extraction
1424 // of the entire group of 4 elements to which that index belongs.
1425 //
1426 // If the number of scales happens to be indivisible by 4, extract
1427 // the remaining n - m scales in a chunk of 4 elements starting at
1428 // offset n - 4.
1429 int64_t offset = idx - (idx % 4);
1430 int64_t opsel = idx - offset;
1431 int64_t size = 4l;
1432 // Accomdate remaining elements in the case of non-4-divisible vectors.
1433 if (numElements - offset < size) {
1434 opsel = size - (numElements - idx);
1435 offset = numElements - 4l;
1436 }
1437 Type scaleSrcElemType = scaleSrcType.getElementType();
1438 auto newSrcType =
1439 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
1440 Value newScaleSrc =
1441 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
1442 auto extract = vector::ExtractStridedSliceOp::create(
1443 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
1444 ArrayRef{int64_t(1)});
1445 rewriter.modifyOpInPlace(op, [&] {
1446 op->setOperand(opIdx, extract);
1447 setOpsel(opIdx, opsel);
1448 });
1449 }
1450 return success();
1451 }
1452};
1453} // namespace
1454
1455void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
1456 MLIRContext *context) {
1457 results.add<PackScales>(context);
1458}
1459
1460//===----------------------------------------------------------------------===//
1461// In-LDS Barrier Operations (gfx1250+)
1462//===----------------------------------------------------------------------===//
1463
1464template <typename T>
1465static LogicalResult verifyDsBarrierOpCommon(T &op) {
1466 MemRefType memrefType = llvm::cast<MemRefType>(op.getBase().getType());
1467 if (failed(
1468 verifyIndexCount(op, "barrier", memrefType, op.getIndices().size())))
1469 return failure();
1470
1471 if (!isWorkgroupMemorySpace(memrefType.getMemorySpace()))
1472 return op.emitOpError("barrier must be in workgroup (LDS) memory");
1473
1474 return success();
1475}
1476
1477LogicalResult DsBarrierInitOp::verify() {
1478 return verifyDsBarrierOpCommon(*this);
1479}
1480
1481LogicalResult DsBarrierPollStateOp::verify() {
1482 return verifyDsBarrierOpCommon(*this);
1483}
1484
1485LogicalResult DsAsyncBarrierArriveOp::verify() {
1486 return verifyDsBarrierOpCommon(*this);
1487}
1488
1489LogicalResult DsBarrierArriveOp::verify() {
1490 return verifyDsBarrierOpCommon(*this);
1491}
1492
1493//===----------------------------------------------------------------------===//
1494// GlobalPrefetchOp
1495//===----------------------------------------------------------------------===//
1496
1497LogicalResult GlobalPrefetchOp::verify() {
1498 auto src = cast<MemRefType>(getSrc().getType());
1499
1500 if (failed(verifyIndexCount(*this, "source", src, getIndices().size())))
1501 return failure();
1502
1503 Attribute memSpace = src.getMemorySpace();
1504 if (!memSpace)
1505 return this->emitOpError("the source must have address space attribute");
1506 if (!isGlobalMemorySpace(memSpace, /*allowFlat=*/false))
1507 return this->emitOpError("the source must reside in global address space");
1508
1509 const LoadTemporalHint temporalHint = getTemporalHint();
1510 const Scope scope = getCacheScope();
1511 const bool isSpeculative = getSpeculative();
1512
1513 // See GFX1250 SPG for a detail explanation
1514 if (isSpeculative && scope == Scope::WGP)
1515 return this->emitOpError(
1516 "does not support speculative prefetch in WGP scope");
1517
1518 // Note that temporal hints are shared between load, store,
1519 // prefetch, etc. instructions. However, some instructions
1520 // operate only with a subset of hints according to the ISA
1521 // documentation. In case of global prefetch, non-temporal (NT)
1522 // and last-use (LU) hints are not used. The extra bits of encoding
1523 // are used to encode speculative or non-speculative instruction behavior
1524 if (llvm::is_contained({LoadTemporalHint::NT, LoadTemporalHint::LU},
1525 temporalHint))
1526 return this->emitOpError("does not support NT and LU modes");
1527
1528 if (llvm::is_contained({LoadTemporalHint::NT_RT, LoadTemporalHint::RT_NT,
1529 LoadTemporalHint::NT_HT},
1530 temporalHint) &&
1531 !isSpeculative) {
1532 return this->emitOpError("operates only in the speculative mode");
1533 }
1534 return success();
1535}
1536
1537#define GET_OP_CLASSES
1538#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyDescriptorOp(DescriptorOp op)
static LogicalResult verifyRawBufferOp(T &op)
static LogicalResult verifyDsBarrierOpCommon(T &op)
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless stripOffset is tr...
Definition AMDGPUOps.cpp:87
static LogicalResult verifyIndexCount(OpTy op, StringRef indexName, MemRefType memrefType, int64_t numIndices)
Verifies that the number of indices matches the rank of the indexed memref, emitting an op error ment...
Definition AMDGPUOps.cpp:45
static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op, PatternRewriter &rewriter)
Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
static LogicalResult foldGlobalLoadAsyncToLDSConstantMask(GlobalLoadAsyncToLDSOp op, PatternRewriter &rewriter)
static LogicalResult verifyBase(BaseOp op)
static bool staticallyOutOfBounds(OpType op)
static std::optional< uint32_t > getConstantUint32(Value v)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isFloat() const
Return true if this is an float type (with the specified width).
Definition Types.cpp:47
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool isGlobalMemorySpace(Attribute memorySpace, bool allowFlat)
ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m, IntegerAttr &n, IntegerAttr &k)
Parser for the custom<MNKDimensionList> custom assembly format used by WMMAOp.
bool isWorkgroupMemorySpace(Attribute memorySpace)
bool isFatRawBufferMemorySpace(Attribute memorySpace)
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:70
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...