MLIR 23.0.0git
AMDGPUOps.cpp
Go to the documentation of this file.
1//===- AMDGPUOps.cpp - MLIR AMDGPU dialect operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect operations, their verifiers, and
10// their canonicalizations.
11//
12//===----------------------------------------------------------------------===//
13
15
23#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
26#include "mlir/IR/Matchers.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallVector.h"
32
33#include <algorithm>
34#include <cstdint>
35#include <limits>
36#include <optional>
37
38using namespace mlir;
39using namespace mlir::amdgpu;
40
41/// Verifies that the number of indices matches the rank of the indexed memref,
42/// emitting an op error mentioning `indexName` on mismatch.
43template <typename OpTy>
44static LogicalResult verifyIndexCount(OpTy op, StringRef indexName,
45 MemRefType memrefType,
46 int64_t numIndices) {
47 int64_t rank = memrefType.getRank();
48 if (rank != numIndices)
49 return op.emitOpError("expected ")
50 << rank << " " << indexName << " indices, got " << numIndices;
51 return success();
52}
53
54//===----------------------------------------------------------------------===//
55// 8-bit float ops
56//===----------------------------------------------------------------------===//
57LogicalResult PackedTrunc2xFp8Op::verify() {
58 if (getExisting() && getExisting().getType() != getResult().getType())
59 return emitOpError("existing values must have same type as result");
60 return success();
61}
62
63LogicalResult PackedStochRoundFp8Op::verify() {
64 if (getExisting() && getExisting().getType() != getResult().getType())
65 return emitOpError("existing values must have same type as result");
66 return success();
67}
68
69//===----------------------------------------------------------------------===//
70// mxfp float ops
71//===----------------------------------------------------------------------===//
72LogicalResult PackedScaledTruncOp::verify() {
73 if (getExisting() && getExisting().getType() != getResult().getType())
74 return emitOpError("existing values must have same type as result");
75 return success();
76}
77
78//===----------------------------------------------------------------------===//
79// FatRawBufferCastOp
80//===----------------------------------------------------------------------===//
81
82/// Convert the type `source` to one with the same sizes and strides - and
83/// offset, unless `stripOffset` is true, in which case the offset is reset to
84/// 0, if the offset should be reset but the layout of `source` isn't either the
85/// identity layout or a strided layout, this function fails.
86static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
87 bool resetOffset) {
88 MLIRContext *ctx = source.getContext();
89 MemRefType::Builder mb(source);
91 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
92 MemRefLayoutAttrInterface layout = source.getLayout();
93 if (resetOffset && !layout.isIdentity()) {
94 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
95 if (!stridedLayout)
96 return failure();
97 MemRefLayoutAttrInterface newLayout =
98 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
99 // Special case: if resetting the offset causes the strided layout to become
100 // the identity layout, then reset to the identity layout.
101 // TODO: this'll get a lot simpler when we have the contiguous layout.
102 SmallVector<int64_t> stridesIfIdentity;
103 if (source.hasStaticShape()) {
104 stridesIfIdentity = computeSuffixProduct(source.getShape());
105 } else if (source.getRank() <= 1) {
106 stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
107 }
108 if (stridesIfIdentity == stridedLayout.getStrides()) {
109 newLayout = AffineMapAttr::get(
110 AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
111 }
112 mb.setLayout(newLayout);
113 }
114 return (MemRefType)(mb);
115}
116
117LogicalResult FatRawBufferCastOp::inferReturnTypes(
118 MLIRContext *context, std::optional<Location> location, ValueRange operands,
119 DictionaryAttr attributes, PropertyRef properties, RegionRange regions,
120 SmallVectorImpl<Type> &inferredReturnTypes) {
121 Adaptor adaptor(operands, attributes, properties, regions);
122 auto sourceType =
123 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
124 if (!sourceType)
125 return failure();
126 FailureOr<MemRefType> resultType =
127 getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
128 if (failed(resultType))
129 return failure();
130 inferredReturnTypes = SmallVector<Type>{*resultType};
131 return success();
132}
133
134FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
135 int resultIndex,
136 int dim) {
137 assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
138 return memref::getMixedSize(builder, getLoc(), getSource(), dim);
139}
140
141LogicalResult FatRawBufferCastOp::verify() {
142 FailureOr<MemRefType> expectedResultType =
143 getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
144 if (failed(expectedResultType))
145 return emitOpError("source type ")
146 << getSource().getType() << " can't have its offset reset";
147 if (getResult().getType() != *expectedResultType)
148 return emitOpError("expected result type to be ")
149 << *expectedResultType << " but got " << getResult().getType();
150 return success();
151}
152
153static bool hasGlobalMemorySpace(Attribute memorySpace) {
154 if (!memorySpace)
155 return true;
156 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
157 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
158 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
159 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
160 return false;
161}
162
163static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
164 if (!memorySpace)
165 return false;
166 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
167 return intMemorySpace.getInt() == 3;
168 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
169 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
170 return false;
171}
172
173static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
174 if (!memorySpace)
175 return false;
176 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
177 return intMemorySpace.getInt() == 7;
178 if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
179 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
180 return false;
181}
182
183//===----------------------------------------------------------------------===//
184// RawBuffer*Op
185//===----------------------------------------------------------------------===//
186template <typename T>
187static LogicalResult verifyRawBufferOp(T &op) {
188 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
189 bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
190
191 if (!isGlobal)
192 return op.emitOpError(
193 "buffer ops must operate on a memref in global memory");
194 if (!bufferType.hasRank())
195 return op.emitOpError(
196 "cannot meaningfully buffer_store to an unranked memref");
197 return verifyIndexCount(op, "buffer", bufferType, op.getIndices().size());
198}
199
// The raw-buffer load, store, and atomic ops all share the structural checks
// implemented by verifyRawBufferOp above (global memory space, ranked memref,
// index count equal to the memref rank).
LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
223
224static std::optional<uint32_t> getConstantUint32(Value v) {
225 APInt cst;
226 if (!v.getType().isInteger(32))
227 return std::nullopt;
228 if (matchPattern(v, m_ConstantInt(&cst)))
229 return cst.getZExtValue();
230 return std::nullopt;
231}
232
233template <typename OpType>
234static bool staticallyOutOfBounds(OpType op) {
235 if (!op.getBoundsCheck())
236 return false;
237 MemRefType bufferType = op.getMemref().getType();
238 if (!bufferType.hasStaticShape())
239 return false;
240 int64_t offset;
241 SmallVector<int64_t> strides;
242 if (failed(bufferType.getStridesAndOffset(strides, offset)))
243 return false;
244 int64_t result = offset + op.getIndexOffset().value_or(0);
245 if (op.getSgprOffset()) {
246 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
247 if (!sgprOffset)
248 return false;
249 result += *sgprOffset;
250 }
251 if (strides.size() != op.getIndices().size())
252 return false;
253 int64_t indexVal = 0;
254 for (auto pair : llvm::zip(strides, op.getIndices())) {
255 int64_t stride = std::get<0>(pair);
256 Value idx = std::get<1>(pair);
257 std::optional<uint32_t> idxVal = getConstantUint32(idx);
258 if (!idxVal)
259 return false;
260 indexVal += stride * *idxVal;
261 }
262 result += indexVal;
263 if (result > std::numeric_limits<uint32_t>::max())
264 // Overflow means don't drop
265 return false;
266 return result >= bufferType.getNumElements();
267}
268
269namespace {
270template <typename OpType>
271struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
272 using OpRewritePattern<OpType>::OpRewritePattern;
273
274 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
275 if (!staticallyOutOfBounds(op))
276 return failure();
277 Type loadType = op.getResult().getType();
278 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
279 rw.getZeroAttr(loadType));
280 return success();
281 }
282};
283
284template <typename OpType>
285struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
286 using OpRewritePattern<OpType>::OpRewritePattern;
287
288 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
289 if (!staticallyOutOfBounds(op))
290 return failure();
291
292 rw.eraseOp(op);
293 return success();
294 }
295};
296} // end namespace
297
// Canonicalizations: statically out-of-bounds accesses are dropped. Ops that
// produce a value (load and cmpswap) fold to a zero constant of the result
// type; stores and the value-less atomics are erased outright.
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
333
334//===----------------------------------------------------------------------===//
335// ScaledExtPackedMatrixOp
336//===----------------------------------------------------------------------===//
337LogicalResult ScaledExtPackedMatrixOp::verify() {
338 int blockSize = getBlockSize();
339 assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
340
341 int firstScaleByte = getFirstScaleByte();
342 int firstScaleLane = getFirstScaleLane();
343 auto sourceType = cast<VectorType>(getSource().getType());
344 Type elementType = sourceType.getElementType();
345 auto floatType = cast<FloatType>(elementType);
346 unsigned bitWidth = floatType.getWidth();
347
348 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
349
350 const bool is_fp8 = bitWidth == 8;
351 const bool is_block_16 = blockSize == 16;
352
353 if (!is_fp8) {
354 if (is_block_16) {
355 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
356 return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
357 "or 1 for f4 and f6.");
358 }
359 } else {
360 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
361 return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
362 "or 2 for f4 and f6.");
363 }
364 }
365 } else {
366 if (is_block_16) {
367 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
368 ((firstScaleLane == 16) && (firstScaleByte == 2));
369 if (!is_valid) {
370 return emitOpError("blockSize of 16 can only have (firstScaleLane, "
371 "firstScaleByte) be (0, 0) or (16, 2) for f8.");
372 }
373 }
374 }
375
376 return success();
377}
378
379//===----------------------------------------------------------------------===//
380// WMMAOp
381//===----------------------------------------------------------------------===//
382
384 IntegerAttr &m, IntegerAttr &n,
385 IntegerAttr &k) {
386 SmallVector<int64_t, 3> dimensions;
387 if (parser.parseDimensionList(dimensions, false, false))
388 return failure();
389 if (dimensions.size() != 3)
390 return parser.emitError(parser.getCurrentLocation())
391 << "expected 3 dimensions in MNK dimension list";
392
393 m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
394 n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
395 k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
396 return success();
397}
398
399LogicalResult WMMAOp::verify() {
400 auto sourceAType = cast<VectorType>(getSourceA().getType());
401 auto sourceBType = cast<VectorType>(getSourceB().getType());
402 auto destType = cast<VectorType>(getDestC().getType());
403
404 Type sourceAElemType = sourceAType.getElementType();
405 Type sourceBElemType = sourceBType.getElementType();
406 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
407 return emitOpError("source vectors have different lengths: ")
408 << sourceAType << " vs. " << sourceBType;
409 }
410
411 bool isDestFloat = destType.getElementType().isFloat();
412 bool isSrcFloat = sourceAElemType.isFloat();
413
414 if (isDestFloat && !isSrcFloat)
415 return emitOpError("expected float sources with float destination");
416 if (!isDestFloat && isSrcFloat)
417 return emitOpError("expected int sources with int destination");
418
419 if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
420 return emitOpError(
421 "source element types must match (except for fp8/bf8) but have ")
422 << sourceAType << " and " << sourceBType;
423 }
424
425 if (isSrcFloat) {
426 if (getClamp())
427 return emitOpError("clamp flag is not supported for float types");
428 if (getUnsignedA() || getUnsignedB())
429 return emitOpError("unsigned flags are not supported for float types");
430 }
431 return success();
432}
433
434//===----------------------------------------------------------------------===//
435// ScaledWMMAOp
436//===----------------------------------------------------------------------===//
437
/// Verifier for scaled WMMA: checks destination/source vector lengths against
/// the M dimension, then validates the (matrix element type, scale element
/// type) combination for both operands.
LogicalResult ScaledWMMAOp::verify() {
  // Helper predicates for classifying matrix and scale element types.
  auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
  auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
  auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
  auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
  auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
  auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;

  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  // Element types of the A and B matrix operands.
  Type aElemType = sourceAType.getElementType();
  Type bElemType = sourceBType.getElementType();

  // Validate vector lengths based on dimensions (m selects 16x16x128 vs
  // 32x16x128).
  int64_t m = getM();
  int64_t aLen = sourceAType.getNumElements();
  int64_t bLen = sourceBType.getNumElements();
  int64_t expectedOutLen = (m == 16) ? 8 : 16;

  if (destType.getNumElements() != expectedOutLen)
    return emitOpError("expected output vector of length ")
           << expectedOutLen << " but got " << destType.getNumElements();

  if (m == 16) {
    // For 16×16×128: both A and B must be 64 elements.
    if (aLen != 64)
      return emitOpError(
                 "for 16x16x128, sourceA must have 64 elements but got ")
             << aLen;
    if (bLen != 64)
      return emitOpError(
                 "for 16x16x128, sourceB must have 64 elements but got ")
             << bLen;
  } else { // m == 32
    // For 32×16×128: only fp4 is supported, A is 128, B is 64.
    // NOTE(review): this condition only fires when *neither* operand is fp4;
    // if 32x16x128 requires BOTH operands to be fp4 (as the diagnostic text
    // suggests), this should be `||` rather than `&&` — confirm against the
    // ISA spec.
    if (!isF4(aElemType) && !isF4(bElemType))
      return emitOpError("32x16x128 only supports fp4 element types");

    if (aLen != 128)
      return emitOpError(
                 "for 32x16x128, sourceA must have 128 elements but got ")
             << aLen;
    if (bLen != 64)
      return emitOpError(
                 "for 32x16x128, sourceB must have 64 elements but got ")
             << bLen;

    // For 32x16x128, matrix A uses all 32 lanes so a_first_scale_lane must be
    // 0.
    if (getAFirstScaleLane() != 0)
      return emitOpError("for 32x16x128, a_first_scale_lane must be 0");
  }

  // Validate scale types and their compatibility with matrix element types.
  auto scaleAType = cast<VectorType>(getScaleA().getType());
  auto scaleBType = cast<VectorType>(getScaleB().getType());
  Type scaleAElemType = scaleAType.getElementType();
  Type scaleBElemType = scaleBType.getElementType();

  // Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN).
  if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
    return emitOpError(
        "scale operands must have f8 element types (E8M0FNU or E4M3FN)");

  // Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid.
  if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
    return success();

  // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E4M3).
  if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
      isF4(bElemType) && isE4M3(scaleBElemType))
    return success();

  // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E4M3), Scale B (E8M0).
  if (isF4(aElemType) && isE4M3(scaleAElemType) &&
      (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
    return success();

  // Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3).
  if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
      isE4M3(scaleBElemType))
    return success();

  // No valid combination matched.
  return emitOpError("invalid combination of matrix and scale types: ")
         << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
         << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
}
530
531//===----------------------------------------------------------------------===//
532// MFMAOp
533//===----------------------------------------------------------------------===//
534LogicalResult MFMAOp::verify() {
535 constexpr uint32_t waveSize = 64;
537
538 Type sourceType = getSourceA().getType();
539 Type destType = getDestC().getType();
540
541 Type sourceElem = sourceType, destElem = destType;
542 uint32_t sourceLen = 1, destLen = 1;
543 if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
544 sourceLen = sourceVector.getNumElements();
545 sourceElem = sourceVector.getElementType();
546 }
547 if (auto destVector = dyn_cast<VectorType>(destType)) {
548 destLen = destVector.getNumElements();
549 destElem = destVector.getElementType();
550 }
551
552 Type sourceBType = getSourceB().getType();
553 if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
554 int64_t sourceBLen = 1;
555 Type sourceBElem = sourceBType;
556 if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
557 sourceBLen = sourceBVector.getNumElements();
558 sourceBElem = sourceBVector.getElementType();
559 }
560 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
561 !sourceBElem.isFloat(4))
562 return emitOpError("expected both source operands to have small-float "
563 "elements if one does");
564 if (sourceLen != sourceBLen)
565 return emitOpError(
566 "expected both small-float source vectors to have the same length");
567 } else {
568 if (sourceType != sourceBType)
569 return emitOpError("expected both non-small-float source operand types "
570 "to match exactly");
571 }
572 // Normalize the wider integer types the compiler expects to i8.
573 if (sourceElem.isInteger(32)) {
574 sourceLen *= 4;
575 sourceElem = b.getI8Type();
576 }
577 if (sourceElem.isInteger(64)) {
578 sourceLen *= 8;
579 sourceElem = b.getI8Type();
580 }
581
582 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
583 if (sourceLen != numSourceElems)
584 return emitOpError("expected " + Twine(numSourceElems) +
585 " source values for this operation but got " +
586 Twine(sourceLen));
587
588 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
589 if (destLen != numDestElems)
590 return emitOpError("expected " + Twine(numDestElems) +
591 " result values for this operation but got " +
592 Twine(destLen));
593
594 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
595 return emitOpError(
596 "double-precision ops do not support permuting lanes of B");
597 if (destElem.isF64() && getCbsz() != 0)
598 return emitOpError(
599 "double-precision ops do not support permuting lanes of A");
600 if (getAbid() >= (1u << getCbsz()))
601 return emitOpError(
602 "block ID for permuting A (abid) must be below 2 ** cbsz");
603
604 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
605 return emitOpError(
606 "negation flags only available for double-precision operations");
607
608 return success();
609}
610
611//===----------------------------------------------------------------------===//
612// SparseMFMAOp
613//===----------------------------------------------------------------------===//
614
/// Verifier for sparse MFMA: checks the dense/sparse operand length ratio,
/// element-type compatibility, the CBSZ/ABID modifier ranges, the
/// sparse-index operand layout, and the per-lane vector lengths implied by
/// M/N/K.
LogicalResult SparseMFMAOp::verify() {
  constexpr uint32_t waveSize = 64;

  // sourceA is the (compressed) sparse operand, sourceB the dense one.
  auto sparseType = cast<VectorType>(getSourceA().getType());
  auto denseType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sparseElem = sparseType.getElementType();
  Type denseElem = denseType.getElementType();
  int64_t sparseLen = sparseType.getNumElements();
  int64_t denseLen = denseType.getNumElements();
  int64_t destLen = destType.getNumElements();

  if (denseLen != 2 * sparseLen)
    return emitOpError("expected dense source operand to have exactly double "
                       "the number of elements of the sparse source operand");

  // Check that source element types are compatible.
  // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
  // For other types, element types must match exactly.
  bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
  if (!bothFloat8 && sparseElem != denseElem)
    return emitOpError(
        "expected source operands to have the same element type");

  // Classify the sparse MFMA variant. The three flavors differ in CBSZ/ABID
  // handling and in the sparse-index layout:
  //   - gfx942 16-bit: max ABID = 3, sparse idx = vector<4xi8>
  //   - gfx950 16-bit / gfx942 8-bit: max ABID = 1, sparse idx = vector<2xi16>
  //   - gfx950 8-bit: CBSZ/ABID ignored by hw, sparse idx = i32
  uint32_t m = getM(), k = getK();
  bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
  bool is16BitGfx942 =
      !is8BitSource && ((m == 16 && k == 32) || (m == 32 && k == 16));
  bool is8BitGfx950 =
      is8BitSource && ((m == 16 && k == 128) || (m == 32 && k == 64));

  // CBSZ/ABID range check. On gfx950 8-bit the hardware always uses the first
  // set and ignores these fields, so require zeros in IR. Otherwise ABID is
  // only meaningful when CBSZ == 0 (when CBSZ != 0 the first set is always
  // used and ABID is irrelevant, so the verifier accepts any value).
  if (is8BitGfx950) {
    if (getCbsz() != 0)
      return emitOpError(
          "CBSZ must be 0 for this variant (field is ignored by hardware)");
    if (getAbid() != 0)
      return emitOpError(
          "ABID must be 0 for this variant (field is ignored by hardware)");
  } else if (getCbsz() == 0) {
    unsigned maxAbid = is16BitGfx942 ? 3u : 1u;
    if (getAbid() > maxAbid)
      return emitOpError("ABID must be in [0, ")
             << maxAbid << "] for this variant";
  }

  // The sparse-index operand's shape is fixed per variant (see table above).
  Type sparseIdxType = getSparseIdx().getType();
  if (is8BitGfx950) {
    if (!sparseIdxType.isInteger(32))
      return emitOpError("expected i32 sparse indices for this variant "
                         "(no internal set structure), but got ")
             << sparseIdxType;
  } else {
    unsigned expectedIdxElems = is16BitGfx942 ? 4 : 2;
    unsigned expectedIdxBits = is16BitGfx942 ? 8 : 16;
    auto vecType = dyn_cast<VectorType>(sparseIdxType);
    if (!vecType || vecType.getNumElements() != expectedIdxElems ||
        !vecType.getElementType().isInteger(expectedIdxBits))
      return emitOpError("expected vector<")
             << expectedIdxElems << "xi" << expectedIdxBits
             << "> sparse indices for this variant, but got " << sparseIdxType;
  }

  // Length checks are driven by the dense operand; the sparse operand's
  // length is already tied to it by the 2x check above.
  int64_t expectedSourceElems = (getM() * getK()) / waveSize;
  if (denseLen != expectedSourceElems)
    return emitOpError("expected " + Twine(expectedSourceElems) +
                       " source values for this operation but got " +
                       Twine(denseLen));

  int64_t expectedDestElems = (getM() * getN()) / waveSize;
  if (destLen != expectedDestElems)
    return emitOpError("expected " + Twine(expectedDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  return success();
}
701
702//===----------------------------------------------------------------------===//
703// SparseWMMAOp
704//===----------------------------------------------------------------------===//
705
706LogicalResult SparseWMMAOp::verify() {
707 auto sparseType = cast<VectorType>(getSourceA().getType());
708 auto denseType = cast<VectorType>(getSourceB().getType());
709 auto destType = cast<VectorType>(getDestC().getType());
710
711 Type sparseElem = sparseType.getElementType();
712 Type denseElem = denseType.getElementType();
713 Type destElem = destType.getElementType();
714 int64_t sparseLen = sparseType.getNumElements();
715 int64_t denseLen = denseType.getNumElements();
716 int64_t destLen = destType.getNumElements();
717
718 uint32_t m = getM(), n = getN(), k = getK();
719 if ((m != 16) || (n != 16))
720 return emitOpError("expected MxN to be exactly 16x16");
721
722 const bool isWavesize64 = getWave64();
723 const bool isInt4Input = sparseElem.isInteger(4) && denseElem.isInteger(4);
724 const bool isEqualLengthAllowed = isWavesize64 && isInt4Input && k == 32;
725
726 if ((denseLen != 2 * sparseLen) && !isEqualLengthAllowed)
727 return emitOpError("expected dense source operand to have exactly double "
728 "the number of elements of the sparse source operand");
729
730 if (isEqualLengthAllowed && (denseLen != sparseLen))
731 return emitOpError("expected dense source operand to have exactly the "
732 "same the number of elements");
733
734 if (destElem.isInteger()) {
735 if (!(sparseElem.isInteger() && denseElem.isInteger())) {
736 return emitOpError("source operand and destination operands must all be "
737 "either integer or float types");
738 }
739 }
740
741 if (destElem.isFloat()) {
742 if (!(sparseElem.isFloat() && denseElem.isFloat())) {
743 return emitOpError("source operand and destination operands must all be "
744 "either integer or float types");
745 }
746 }
747
748 // Check that source element types are compatible.
749 // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
750 // For other types, element types must match exactly.
751 bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
752 if (!bothFloat8 && sparseElem != denseElem)
753 return emitOpError(
754 "expected source operands to have the same element type");
755
756 const int64_t waveSize = isWavesize64 ? 64 : 32;
757
758 int64_t expectedSourceElems = (getM() * getK()) / waveSize;
759 if (denseLen != expectedSourceElems)
760 return emitOpError("expected " + Twine(expectedSourceElems) +
761 " source values for this operation but got " +
762 Twine(denseLen));
763
764 int64_t expectedDestElems = (getM() * getN()) / waveSize;
765 if (destLen != expectedDestElems)
766 return emitOpError("expected " + Twine(expectedDestElems) +
767 " result values for this operation but got " +
768 Twine(destLen));
769
770 return success();
771}
772
773//===----------------------------------------------------------------------===//
774// DotOp
775//===----------------------------------------------------------------------===//
776LogicalResult DotOp::verify() {
777 Type aElem = cast<VectorType>(getSourceA().getType()).getElementType();
778 Type bElem = cast<VectorType>(getSourceB().getType()).getElementType();
779 Type dest = getDestC().getType();
780
781 bool aIsFloat8 = aElem.isFloat(8);
782 bool bIsFloat8 = bElem.isFloat(8);
783 bool aIsInteger = isa<IntegerType>(aElem);
784
785 bool bothFloat8 = aIsFloat8 && bIsFloat8;
786 if (!bothFloat8 && aElem != bElem)
787 return emitOpError(
788 "expected source operands to have the same element type");
789
790 if (aElem.isF16()) {
791 if (!dest.isF32() && !dest.isF16())
792 return emitOpError("expected f32 or f16 accumulator for f16 sources");
793 } else if (aElem.isBF16()) {
794 if (!dest.isF32() && !dest.isBF16())
795 return emitOpError("expected f32 or bf16 accumulator for bf16 sources");
796 } else if (aIsInteger) {
797 if (!dest.isInteger(32))
798 return emitOpError("expected i32 accumulator for integer sources");
799 } else if (aIsFloat8) {
800 if (!dest.isF32())
801 return emitOpError("expected f32 accumulator for fp8 sources");
802 }
803
804 if ((getUnsignedA() || getUnsignedB()) && !aIsInteger)
805 return emitOpError(
806 "unsignedA/unsignedB are only valid for integer source types");
807
808 if (aElem.isInteger(16) && getUnsignedA() != getUnsignedB())
809 return emitOpError(
810 "mixed-sign dot is not supported for 16-bit integer sources");
811
812 if (getClamp()) {
813 bool noClamp = (aElem.isF16() && dest.isF16()) ||
814 (aElem.isBF16() && dest.isBF16()) || aIsFloat8;
815 if (noClamp)
816 return emitOpError(
817 "clamp is not supported for this (source, accumulator) combination");
818 }
819
820 return success();
821}
822
823//===----------------------------------------------------------------------===//
824// DPPOp
825//===----------------------------------------------------------------------===//
826LogicalResult DPPOp::verify() {
827 DPPPerm kind = getKind();
828 Attribute permArgument = getPermArgument().value_or(Attribute{});
829
830 switch (kind) {
831
832 case DPPPerm::quad_perm: {
833 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
834 if (!quadPermAttr || quadPermAttr.size() != 4) {
835 return emitOpError("quad_perm attribute must have exactly 4 elements");
836 }
837 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
838 int32_t num = elem.getInt();
839 if (num < 0 || num > 3) {
840 return emitOpError(
841 "Each element of quad_perm must be in the range [0, 3]");
842 }
843 }
844 } break;
845
846 case DPPPerm::row_shl:
847 case DPPPerm::row_shr:
848 case DPPPerm::row_ror: {
849 if (!permArgument) {
850 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
851 "' value not specified");
852 }
853 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
854 uint32_t attrValue = intAttr.getInt();
855 if (attrValue < 1 || attrValue > 15) {
856 return emitOpError("Attribute value must be between 1 and 15");
857 }
858 }
859 } break;
860
861 case DPPPerm::wave_shl:
862 case DPPPerm::wave_shr:
863 case DPPPerm::wave_rol:
864 case DPPPerm::wave_ror:
865 case DPPPerm::row_mirror:
866 case DPPPerm::row_half_mirror:
867 case DPPPerm::row_bcast_15:
868 case DPPPerm::row_bcast_31: {
869 if (permArgument && !isa<UnitAttr>(permArgument)) {
870 return emitOpError("Expected unit attribute for permArgument, but found "
871 "non-trivial argument");
872 }
873 break;
874 }
875 }
876 return success();
877}
878
879//===----------------------------------------------------------------------===//
880// PermlaneSwapOp
881//===----------------------------------------------------------------------===//
882LogicalResult PermlaneSwapOp::verify() {
883 unsigned rowLength = getRowLength();
884
885 if (rowLength != 16 && rowLength != 32)
886 return emitOpError("row_length attribute must either be 16 or 32.");
887
888 return success();
889}
890
891/// Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
892static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op,
893 PatternRewriter &rewriter) {
894 if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
895 rewriter.eraseOp(op);
896 return success();
897 }
898 return failure();
899}
900
901void LDSBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
902 MLIRContext *context) {
904}
905
906//===----------------------------------------------------------------------===//
907// MemoryCounterWaitOp
908//===----------------------------------------------------------------------===//
909
910namespace {
911/// Fuse adjacent memory counter wait ops, taking the minimum value of the
912/// counters.
913struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
914 using Base::Base;
915
916 LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
917 PatternRewriter &rewriter) const override {
918 auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
919 if (!next)
920 return failure();
921
922 auto setters = {&MemoryCounterWaitOp::setLoad,
923 &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
924 &MemoryCounterWaitOp::setExp,
925 &MemoryCounterWaitOp::setTensor};
926 auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
927 op.getTensor()};
928 auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
929 next.getExp(), next.getTensor()};
930 rewriter.modifyOpInPlace(op, [&] {
931 for (auto [setter, lhs, rhs] :
932 llvm::zip_equal(setters, lhsVals, rhsVals)) {
933 if (lhs && rhs) {
934 (op.*setter)(std::min(*lhs, *rhs));
935 } else if (lhs) {
936 (op.*setter)(*lhs);
937 } else if (rhs) {
938 (op.*setter)(*rhs);
939 }
940 }
941 });
942 rewriter.eraseOp(next);
943 return success();
944 }
945};
946} // namespace
947
void MemoryCounterWaitOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Merge back-to-back waits into a single op carrying the minimum counters.
  results.add<FuseMemoryCounterWaitOp>(context);
}
952
953//===----------------------------------------------------------------------===//
954// GatherToLDSOp
955//===----------------------------------------------------------------------===//
956
957LogicalResult GatherToLDSOp::verify() {
958 MemRefType srcType = cast<MemRefType>(getSrc().getType());
959 MemRefType dstType = cast<MemRefType>(getDst().getType());
960
961 if (failed(
962 verifyIndexCount(*this, "source", srcType, getSrcIndices().size())) ||
963 failed(verifyIndexCount(*this, "destination", dstType,
964 getDstIndices().size())))
965 return failure();
966
967 if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
968 return emitOpError("destination type inner most dim must be contiguous");
969
970 auto elemType = srcType.getElementType();
971 // Check $src and $dst element types are the same.
972 if (elemType != dstType.getElementType())
973 return emitOpError("source and destination element types must match");
974
975 // copy type sizes should be 1, 2, 4, 12 or 16 bytes.
976 auto transferType = getTransferType();
977 int transferSize;
978 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
979 transferSize = vectorTransfer.getNumElements() *
980 vectorTransfer.getElementTypeBitWidth();
981 } else {
982 transferSize = transferType.getIntOrFloatBitWidth();
983 }
984 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
985 return emitOpError(
986 "Transfering type size must be 8, 16, 32, 96 or 128 bits");
987
988 if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
989 !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
990 return emitOpError(
991 "source memory address space must be global or fat raw buffer");
992
993 if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
994 return emitOpError("destination memory address space must be Workgroup");
995
996 return success();
997}
998
999namespace {
1000/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
1001/// information or changes layout, the cast can be skipped.
1002struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
1004
1005 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
1006 PatternRewriter &rewriter) const override {
1007 bool modified = false;
1008 auto foldCast = [&](OpOperand &operand) {
1009 if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
1010 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
1011 rewriter.modifyOpInPlace(gatherOp,
1012 [&] { operand.assign(castOp.getSource()); });
1013 modified = true;
1014 }
1015 }
1016 };
1017
1018 foldCast(gatherOp.getSrcMutable());
1019 foldCast(gatherOp.getDstMutable());
1020
1021 return success(modified);
1022 }
1023};
1024} // namespace
1025
void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // Skip memref.cast producers that only drop static info or change layout.
  results.add<FoldGatherToLDSOfCast>(context);
}
1030
1031//===----------------------------------------------------------------------===//
1032// GlobalLoadAsyncToLDSOp
1033//===----------------------------------------------------------------------===//
1034
1035LogicalResult GlobalLoadAsyncToLDSOp::verify() {
1036 MemRefType srcType = cast<MemRefType>(getSrc().getType());
1037 MemRefType dstType = cast<MemRefType>(getDst().getType());
1038
1039 if (failed(
1040 verifyIndexCount(*this, "source", srcType, getSrcIndices().size())) ||
1041 failed(verifyIndexCount(*this, "destination", dstType,
1042 getDstIndices().size())))
1043 return failure();
1044
1045 if (srcType.getElementType() != dstType.getElementType())
1046 return emitOpError("source and destination element types must match");
1047
1048 Type transferType = getTransferType();
1049 int transferSize;
1050 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
1051 transferSize = vectorTransfer.getNumElements() *
1052 vectorTransfer.getElementTypeBitWidth();
1053 } else {
1054 transferSize = transferType.getIntOrFloatBitWidth();
1055 }
1056 if (!llvm::is_contained({8, 32, 64, 128}, transferSize))
1057 return emitOpError("transfer type size must be 8, 32, 64, or 128 bits");
1058
1059 if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
1060 return emitOpError("source memory address space must be global");
1061
1062 if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
1063 return emitOpError("destination memory address space must be Workgroup");
1064
1065 return success();
1066}
1067
1068//===----------------------------------------------------------------------===//
1069// TransposeLoadOp
1070//===----------------------------------------------------------------------===//
1071
1072LogicalResult TransposeLoadOp::verify() {
1073 MemRefType srcType = cast<MemRefType>(getSrc().getType());
1074
1075 if (failed(
1076 verifyIndexCount(*this, "source", srcType, getSrcIndices().size())))
1077 return failure();
1078
1079 if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
1080 return emitOpError("source memory address space must be Workgroup");
1081
1082 auto transferType = cast<VectorType>(getType());
1083 size_t numElements = transferType.getNumElements();
1084 size_t elementTypeSize =
1085 transferType.getElementType().getIntOrFloatBitWidth();
1086
1087 // ElementSize -> NumElements
1088 const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
1089 {4, 16},
1090 {6, 16},
1091 {8, 8},
1092 {16, 4},
1093 };
1094
1095 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
1096 if (validNumElems == kValidLoadSizeMap.end())
1097 return emitOpError("Unsupported element type size for transpose load: ")
1098 << elementTypeSize << " bits";
1099
1100 if (numElements != validNumElems->second)
1101 return emitOpError(
1102 "Transferring type size mismatch: expected num of elements: ")
1103 << validNumElems->second;
1104
1105 return success();
1106}
1107
1108//===----------------------------------------------------------------------===//
1109// GlobalTransposeLoadOp
1110//===----------------------------------------------------------------------===//
1111
1112LogicalResult GlobalTransposeLoadOp::verify() {
1113 MemRefType srcType = cast<MemRefType>(getSrc().getType());
1114
1115 if (failed(
1116 verifyIndexCount(*this, "source", srcType, getSrcIndices().size())))
1117 return failure();
1118
1119 if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
1120 return emitOpError("source memory address space must be Global");
1121
1122 auto resultType = cast<VectorType>(getType());
1123 size_t numElements = resultType.getNumElements();
1124 size_t elementTypeSize = resultType.getElementType().getIntOrFloatBitWidth();
1125
1126 // ElementSize -> NumElements. Chipset gating (gfx1200 vs gfx1250) is
1127 // enforced in the lowering.
1128 static const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
1129 {4, 16}, // global_load_tr4_b64 (gfx1250+)
1130 {6, 16}, // global_load_tr6_b96 (gfx1250+)
1131 {8, 8}, // global_load_tr_b64 (gfx1200+)
1132 {16, 8}, // global_load_tr_b128 (gfx1200+)
1133 };
1134
1135 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
1136 if (validNumElems == kValidLoadSizeMap.end())
1137 return emitOpError(
1138 "unsupported element type size for global transpose load: ")
1139 << elementTypeSize << " bits";
1140
1141 if (numElements != validNumElems->second)
1142 return emitOpError(
1143 "transferring type size mismatch: expected num of elements: ")
1144 << validNumElems->second;
1145
1146 return success();
1147}
1148
1149//===----------------------------------------------------------------------===//
1150// MakeDmaBaseOp
1151//===----------------------------------------------------------------------===//
1152
1153template <typename BaseOp>
1154static LogicalResult verifyBase(BaseOp op) {
1155 auto ldsType = cast<MemRefType>(op.getLds().getType());
1156 auto globalType = cast<MemRefType>(op.getGlobal().getType());
1157 if (failed(verifyIndexCount(op, "global", globalType,
1158 op.getGlobalIndices().size())) ||
1159 failed(verifyIndexCount(op, "lds", ldsType, op.getLdsIndices().size())))
1160 return failure();
1161
1162 if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
1163 return op.emitOpError(
1164 "lds memref must have workgroup address space attribute.");
1165 if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
1166 return op.emitOpError(
1167 "global memref must have global address space attribute.");
1168
1169 Type elementType = ldsType.getElementType();
1170 unsigned width = elementType.getIntOrFloatBitWidth();
1171
1172 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
1173 return op.emitOpError(
1174 "element type must be 1, 2, 4, or 8 bytes long but type was ")
1175 << width << " bits long.";
1176 return success();
1177}
1178
1179LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
1180
1181//===----------------------------------------------------------------------===//
1182// MakeGatherDmaBaseOp
1183//===----------------------------------------------------------------------===//
1184
1185LogicalResult
1186TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
1187 Type elementType, Type indexType) {
1188 unsigned width = elementType.getIntOrFloatBitWidth();
1189 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
1190 return emitError()
1191 << "element type must be 1, 2, 4, or 8 bytes wide but type "
1192 << elementType << " is " << width / 8 << " bytes wide.";
1193 MLIRContext *ctx = elementType.getContext();
1194 Type i16 = IntegerType::get(ctx, 32);
1195 Type i32 = IntegerType::get(ctx, 16);
1196 if (!llvm::is_contained({i16, i32}, indexType))
1197 return emitError() << "index type must be i16 or i32 but index type is "
1198 << indexType << ".";
1199 return success();
1200}
1201
1202LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
1203
1204//===----------------------------------------------------------------------===//
1205// MakeDmaDescriptorOp
1206//===----------------------------------------------------------------------===//
1207
/// Shared verifier for MakeDmaDescriptorOp and MakeGatherDmaDescriptorOp.
/// Checks stride/size/rank consistency of the global tensor and shared tile,
/// the element width, and the optional atomic-barrier / timeout settings.
/// The order of the checks is observable through the emitted diagnostics.
template <typename DescriptorOp>
static LogicalResult verifyDescriptorOp(DescriptorOp op) {
  ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();

  // Strides: non-empty, innermost stride must be 1 (contiguous rows).
  if (globalStaticStrides.empty())
    return op.emitOpError("strides must not be empty.");
  if (globalStaticStrides.back() != 1)
    return op.emitOpError("strides for the innermost dimension must be 1.");

  // Ranks: at most 5, and sizes/strides/tile must all agree.
  ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
  size_t rank = globalStaticSizes.size();
  if (rank > 5)
    return op.emitOpError("tensor and tile must be at most of rank 5.");
  if (rank != globalStaticStrides.size())
    return op.emitOpError("strides and sizes must have same rank.");

  ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
  if (rank != sharedStaticSizes.size())
    return op.emitOpError("tensor must have same rank as tile.");

  // Element width must be 1, 2, 4, or 8 bytes.
  unsigned elementTypeWidth = op.getElementTypeWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
    return op.emitOpError(
               "element type width must be 1, 2, 4 or 8 bytes, but was ")
           << elementTypeWidth << " bits long";

  // Barrier indices are only meaningful together with a barrier address.
  if (!op.getAtomicBarrierAddress() && !op.getAtomicBarrierIndices().empty())
    return op.emitOpError(
        "atomic barrier indices require an atomic barrier address");

  // When a barrier address is given, it must be fully indexed and in LDS.
  if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
    auto atomicBarrierAddressType =
        cast<MemRefType>(atomicBarrierAddress.getType());
    if (failed(verifyIndexCount(op, "atomic barrier", atomicBarrierAddressType,
                                op.getAtomicBarrierIndices().size())))
      return failure();

    bool barrierInLDS =
        hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
    if (!barrierInLDS)
      return op.emitOpError("atomic barrier address must be in LDS.");
  }

  // early_timeout only has an effect when a workgroup mask is set.
  if (op.getEarlyTimeout() && !op.getWorkgroupMask())
    return op.emitOpError(
        "early timeout does not apply when workgroup_mask is not set.");
  return success();
}
1256
/// Shared folder for the DMA descriptor ops: folds dynamic sizes/strides
/// that are known constants into their static (attribute) form, returning
/// the op's own result on success and nullptr when nothing changed.
///
/// Note the `&&` chain short-circuits: once an earlier list folds
/// successfully, the later foldDynamicIndexList calls are skipped for this
/// invocation. Iterated folding still converges, because an already-folded
/// list reports failure on the next call, letting the later lists fold then.
template <typename DescriptorOp, typename FoldAdaptor>
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
  SmallVector<OpFoldResult> mixedGlobalSizes(op.getMixedGlobalSizes());
  SmallVector<OpFoldResult> mixedGlobalStrides(op.getMixedGlobalStrides());
  SmallVector<OpFoldResult> mixedSharedSizes(op.getMixedSharedSizes());

  // Bail out (no fold) only when none of the three lists changed.
  if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
                                  /*onlyNonZero=*/true)) &&
      failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
                                  /*onlyNonZero=*/true)) &&
      failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
                                  /*onlyNonZero=*/true)))
    return nullptr;

  SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
      dynamicSharedSizes;
  SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
      staticSharedSizes;

  // Re-split each mixed list into its dynamic-operand and static-attribute
  // halves and write both back onto the op in place.
  dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
                             staticGlobalSizes);
  op.setGlobalStaticSizes(staticGlobalSizes);
  op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);

  dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
                             staticGlobalStrides);
  op.setGlobalStaticStrides(staticGlobalStrides);
  op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);

  dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
                             staticSharedSizes);
  op.setSharedStaticSizes(staticSharedSizes);
  op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
  return op.getResult();
}
1292
// Defers to the descriptor verifier shared with MakeGatherDmaDescriptorOp.
LogicalResult MakeDmaDescriptorOp::verify() {
  return verifyDescriptorOp(*this);
}
1296
// Defers to the shared folder that turns constant dynamic sizes/strides
// into static attributes.
OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  return foldDescriptorOp(*this, adaptor);
}
1300
1301//===----------------------------------------------------------------------===//
1302// MakeGatherDmaDescriptorOp
1303//===----------------------------------------------------------------------===//
1304
1305LogicalResult MakeGatherDmaDescriptorOp::verify() {
1306 ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
1307 size_t rank = globalStaticSizes.size();
1308 if (rank > 2)
1309 return emitOpError(
1310 "tensor and tile must be at most of rank two in gather mode.");
1312 Type elementType = cast<VectorType>(indices.getType()).getElementType();
1313 if (elementType != getBase().getType().getIndexType())
1314 return emitOpError("indices' element type must match base's element type.");
1315
1316 return verifyDescriptorOp(*this);
1317}
1318
// Defers to the shared folder that turns constant dynamic sizes/strides
// into static attributes.
OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  return foldDescriptorOp(*this, adaptor);
}
1322
1323//===----------------------------------------------------------------------===//
1324// ScaledMFMAOp
1325//===----------------------------------------------------------------------===//
1326
1327namespace {
1328/// Check if the scales input is used in other scaled mfma's while they exist.
1329/// If theyre unused then pack the scales.
1330struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
1332
1333 LogicalResult matchAndRewrite(ScaledMFMAOp op,
1334 PatternRewriter &rewriter) const override {
1335 Location loc = op.getLoc();
1336 auto setOpsel = [&op](unsigned idx, int64_t val) {
1337 switch (idx) {
1338 case 3:
1339 op.setScalesIdxA(val);
1340 break;
1341 case 4:
1342 op.setScalesIdxB(val);
1343 break;
1344 default:
1345 break;
1346 }
1347 };
1348
1349 // For every scale operand of this ScaledMFMAOp, if the scale is produced by
1350 // the extraction of a single scale from some vector, then attempt to
1351 // extract 4 values from that vector instead.
1352 //
1353 // Example: (f8 here means f8E8M0FNU)
1354 // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
1355 // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
1356 // amdgpu.scaled_mfma(%scale[0] * ...
1357 //
1358 // rewrite to:
1359 //
1360 // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
1361 // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
1362 // amdgpu.scaled_mfma(%scale[0-3] * ...
1363 //
1364 // This creates duplicate shape_casts for every use but these will be
1365 // removed in CSE.
1366 for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
1367 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
1368 if (!insertOp) {
1369 return rewriter.notifyMatchFailure(op,
1370 "defining op not a vector.insert");
1371 }
1372 // If the extracted value is not a single scalar, then it has been packed.
1373 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
1374 return rewriter.notifyMatchFailure(
1375 op, "scaled mfma operand already packed");
1376 }
1377
1378 auto extractOp =
1379 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
1380 if (!extractOp) {
1381 return rewriter.notifyMatchFailure(op,
1382 "defining op not a vector.extract");
1383 }
1384
1385 Value scaleSrc = extractOp.getOperand(0);
1386 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
1387 if (!scaleSrcType) {
1388 return rewriter.notifyMatchFailure(op, "not a vector type");
1389 }
1390
1391 // We do not handle dynamic dims yet, assume that the input is padded to
1392 // a static shape now.
1393 if (!scaleSrcType.hasStaticShape()) {
1394 return rewriter.notifyMatchFailure(op,
1395 "dynamic dims not yet supported");
1396 }
1397
1398 int64_t numElements = scaleSrcType.getNumElements();
1399 if (numElements < 4) {
1400 return rewriter.notifyMatchFailure(
1401 op, "do not pack if # of scales less than four");
1402 }
1403
1404 // Find a linearized idx using the size and offsets of the extract op.
1405 auto extractedPos = llvm::to_vector_of<int64_t>(
1406 llvm::reverse(extractOp.getStaticPosition()));
1407 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
1408 int64_t scaleSrcRank = scaleSrcType.getRank();
1409 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
1410 for (int64_t i = 1; i < scaleSrcRank; ++i) {
1411 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
1412 }
1413 int64_t idx = linearize(extractedPos, extractSizes);
1414
1415 // All n scales (where n is the total number of scales) must now be
1416 // extracted in chunks of 4 elements. This is done by dividing the
1417 // original vector of scales into groups of 4 elements
1418 // at offsets 0, 4, ..., m (where m = n/4). All extractions of a
1419 // scale at a particular index are now replaced with an extraction
1420 // of the entire group of 4 elements to which that index belongs.
1421 //
1422 // If the number of scales happens to be indivisible by 4, extract
1423 // the remaining n - m scales in a chunk of 4 elements starting at
1424 // offset n - 4.
1425 int64_t offset = idx - (idx % 4);
1426 int64_t opsel = idx - offset;
1427 int64_t size = 4l;
1428 // Accomdate remaining elements in the case of non-4-divisible vectors.
1429 if (numElements - offset < size) {
1430 opsel = size - (numElements - idx);
1431 offset = numElements - 4l;
1432 }
1433 Type scaleSrcElemType = scaleSrcType.getElementType();
1434 auto newSrcType =
1435 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
1436 Value newScaleSrc =
1437 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
1438 auto extract = vector::ExtractStridedSliceOp::create(
1439 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
1440 ArrayRef{int64_t(1)});
1441 rewriter.modifyOpInPlace(op, [&] {
1442 op->setOperand(opIdx, extract);
1443 setOpsel(opIdx, opsel);
1444 });
1445 }
1446 return success();
1447 }
1448};
1449} // namespace
1450
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Pack individually-extracted scales into 4-element chunks (see PackScales).
  results.add<PackScales>(context);
}
1455
1456//===----------------------------------------------------------------------===//
1457// In-LDS Barrier Operations (gfx1250+)
1458//===----------------------------------------------------------------------===//
1459
1460template <typename T>
1461static LogicalResult verifyDsBarrierOpCommon(T &op) {
1462 MemRefType memrefType = llvm::cast<MemRefType>(op.getBase().getType());
1463 if (failed(
1464 verifyIndexCount(op, "barrier", memrefType, op.getIndices().size())))
1465 return failure();
1466
1467 if (!hasWorkgroupMemorySpace(memrefType.getMemorySpace()))
1468 return op.emitOpError("barrier must be in workgroup (LDS) memory");
1469
1470 return success();
1471}
1472
// All four in-LDS barrier ops share the same structural verification:
// index count matches the barrier memref rank, and the memref is in LDS.
LogicalResult DsBarrierInitOp::verify() {
  return verifyDsBarrierOpCommon(*this);
}

LogicalResult DsBarrierPollStateOp::verify() {
  return verifyDsBarrierOpCommon(*this);
}

LogicalResult DsAsyncBarrierArriveOp::verify() {
  return verifyDsBarrierOpCommon(*this);
}

LogicalResult DsBarrierArriveOp::verify() {
  return verifyDsBarrierOpCommon(*this);
}
1488
1489//===----------------------------------------------------------------------===//
1490// GlobalPrefetchOp
1491//===----------------------------------------------------------------------===//
1492
1493LogicalResult GlobalPrefetchOp::verify() {
1494 auto src = cast<MemRefType>(getSrc().getType());
1495
1496 if (failed(verifyIndexCount(*this, "source", src, getIndices().size())))
1497 return failure();
1498
1499 Attribute memSpace = src.getMemorySpace();
1500 if (!memSpace)
1501 return this->emitOpError("the source must have address space attribute");
1502 if (!hasGlobalMemorySpace(memSpace))
1503 return this->emitOpError("the source must reside in global address space");
1504
1505 const LoadTemporalHint temporalHint = getTemporalHint();
1506 const Scope scope = getCacheScope();
1507 const bool isSpeculative = getSpeculative();
1508
1509 // See GFX1250 SPG for a detail explanation
1510 if (isSpeculative && scope == Scope::WGP)
1511 return this->emitOpError(
1512 "does not support speculative prefetch in WGP scope");
1513
1514 // Note that temporal hints are shared between load, store,
1515 // prefetch, etc. instructions. However, some instructions
1516 // operate only with a subset of hints according to the ISA
1517 // documentation. In case of global prefetch, non-temporal (NT)
1518 // and last-use (LU) hints are not used. The extra bits of encoding
1519 // are used to encode speculative or non-speculative instruction behavior
1520 if (llvm::is_contained({LoadTemporalHint::NT, LoadTemporalHint::LU},
1521 temporalHint))
1522 return this->emitOpError("does not support NT and LU modes");
1523
1524 if (llvm::is_contained({LoadTemporalHint::NT_RT, LoadTemporalHint::RT_NT,
1525 LoadTemporalHint::NT_HT},
1526 temporalHint) &&
1527 !isSpeculative) {
1528 return this->emitOpError("operates only in the speculative mode");
1529 }
1530 return success();
1531}
1532
1533#define GET_OP_CLASSES
1534#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyDescriptorOp(DescriptorOp op)
static LogicalResult verifyRawBufferOp(T &op)
static LogicalResult verifyDsBarrierOpCommon(T &op)
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor)
static bool hasGlobalMemorySpace(Attribute memorySpace)
static bool hasWorkgroupMemorySpace(Attribute memorySpace)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless stripOffset is tr...
Definition AMDGPUOps.cpp:86
static LogicalResult verifyIndexCount(OpTy op, StringRef indexName, MemRefType memrefType, int64_t numIndices)
Verifies that the number of indices matches the rank of the indexed memref, emitting an op error ment...
Definition AMDGPUOps.cpp:44
static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op, PatternRewriter &rewriter)
Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
static bool hasFatRawBufferMemorySpace(Attribute memorySpace)
static LogicalResult verifyBase(BaseOp op)
static bool staticallyOutOfBounds(OpType op)
static std::optional< uint32_t > getConstantUint32(Value v)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isFloat() const
Return true if this is an float type (with the specified width).
Definition Types.cpp:47
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m, IntegerAttr &n, IntegerAttr &k)
Parser for the custom<MNKDimensionList> custom assembly format used by WMMAOp.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:70
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...