1//===- AMDGPUOps.cpp - MLIR AMDGPU dialect operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect operations, their verifiers, and
10// their canonicalizations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15
23#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
26#include "mlir/IR/Matchers.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallVector.h"
32
33#include <algorithm>
34#include <cstdint>
35#include <limits>
36#include <optional>
37
38using namespace mlir;
39using namespace mlir::amdgpu;
40
41//===----------------------------------------------------------------------===//
42// 8-bit float ops
43//===----------------------------------------------------------------------===//
44LogicalResult PackedTrunc2xFp8Op::verify() {
45 if (getExisting() && getExisting().getType() != getResult().getType())
46 return emitOpError("existing values must have same type as result");
47 return success();
48}
49
50LogicalResult PackedStochRoundFp8Op::verify() {
51 if (getExisting() && getExisting().getType() != getResult().getType())
52 return emitOpError("existing values must have same type as result");
53 return success();
54}
55
56//===----------------------------------------------------------------------===//
57// mxfp float ops
58//===----------------------------------------------------------------------===//
59LogicalResult PackedScaledTruncOp::verify() {
60 if (getExisting() && getExisting().getType() != getResult().getType())
61 return emitOpError("existing values must have same type as result");
62 return success();
63}
64
65//===----------------------------------------------------------------------===//
66// FatRawBufferCastOp
67//===----------------------------------------------------------------------===//
68
69/// Convert the type `source` to one with the same sizes and strides - and
70/// offset, unless `resetOffset` is true, in which case the offset is reset to
71/// 0. If the offset should be reset but the layout of `source` is neither the
72/// identity layout nor a strided layout, this function fails.
73static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
74 bool resetOffset) {
75 MLIRContext *ctx = source.getContext();
76 MemRefType::Builder mb(source);
77 mb.setMemorySpace(
78 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
79 MemRefLayoutAttrInterface layout = source.getLayout();
80 if (resetOffset && !layout.isIdentity()) {
81 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
82 if (!stridedLayout)
83 return failure();
84 MemRefLayoutAttrInterface newLayout =
85 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
86 // Special case: if resetting the offset causes the strided layout to become
87 // the identity layout, then reset to the identity layout.
88 // TODO: this'll get a lot simpler when we have the contiguous layout.
89 SmallVector<int64_t> stridesIfIdentity;
90 if (source.hasStaticShape()) {
91 stridesIfIdentity = computeSuffixProduct(source.getShape());
92 } else if (source.getRank() <= 1) {
93 stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
94 }
95 if (stridesIfIdentity == stridedLayout.getStrides()) {
96 newLayout = AffineMapAttr::get(
97 AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
98 }
99 mb.setLayout(newLayout);
100 }
101 return (MemRefType)(mb);
102}
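// For example (illustrative types): with `resetOffset` set, a source of type
//   memref<8x8xf32, strided<[8, 1], offset: 4>>
// maps to
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// because the strides [8, 1] equal the suffix product of the static shape,
// whereas strided<[16, 1], offset: 4> would keep a strided<[16, 1], offset: 0>
// layout, and a general affine-map layout would make the conversion fail.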
103
104LogicalResult FatRawBufferCastOp::inferReturnTypes(
105 MLIRContext *context, std::optional<Location> location, ValueRange operands,
106 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
107 SmallVectorImpl<Type> &inferredReturnTypes) {
108 Adaptor adaptor(operands, attributes, properties, regions);
109 auto sourceType =
110 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
111 if (!sourceType)
112 return failure();
113 FailureOr<MemRefType> resultType =
114 getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
115 if (failed(resultType))
116 return failure();
117 inferredReturnTypes = SmallVector<Type>{*resultType};
118 return success();
119}
120
121FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
122 int resultIndex,
123 int dim) {
124 assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
125 return memref::getMixedSize(builder, getLoc(), getSource(), dim);
126}
127
128LogicalResult FatRawBufferCastOp::verify() {
129 FailureOr<MemRefType> expectedResultType =
130 getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
131 if (failed(expectedResultType))
132 return emitOpError("source type ")
133 << getSource().getType() << " can't have its offset reset";
134 if (getResult().getType() != *expectedResultType)
135 return emitOpError("expected result type to be ")
136 << *expectedResultType << " but got " << getResult().getType();
137 return success();
138}
139
140static bool hasGlobalMemorySpace(Attribute memorySpace) {
141 if (!memorySpace)
142 return true;
143 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
144 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
145 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
146 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
147 return false;
148}
149
150static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
151 if (!memorySpace)
152 return false;
153 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
154 return intMemorySpace.getInt() == 3;
155 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
156 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
157 return false;
158}
159
160static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
161 if (!memorySpace)
162 return false;
163 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
164 return intMemorySpace.getInt() == 7;
165 if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
166 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
167 return false;
168}
169
170//===----------------------------------------------------------------------===//
171// RawBuffer*Op
172//===----------------------------------------------------------------------===//
173template <typename T>
174static LogicalResult verifyRawBufferOp(T &op) {
175 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
176 bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
177
178 if (!isGlobal)
179 return op.emitOpError(
180 "Buffer ops must operate on a memref in global memory");
181 if (!bufferType.hasRank())
182 return op.emitOpError(
183 "Cannot meaningfully buffer_store to an unranked memref");
184 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
185 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
186 " indices to memref");
187 return success();
188}
189
190LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
191
192LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
193
194LogicalResult RawBufferAtomicFaddOp::verify() {
195 return verifyRawBufferOp(*this);
196}
197
198LogicalResult RawBufferAtomicFmaxOp::verify() {
199 return verifyRawBufferOp(*this);
200}
201
202LogicalResult RawBufferAtomicSmaxOp::verify() {
203 return verifyRawBufferOp(*this);
204}
205
206LogicalResult RawBufferAtomicUminOp::verify() {
207 return verifyRawBufferOp(*this);
208}
209
210LogicalResult RawBufferAtomicCmpswapOp::verify() {
211 return verifyRawBufferOp(*this);
212}
213
214static std::optional<uint32_t> getConstantUint32(Value v) {
215 APInt cst;
216 if (!v.getType().isInteger(32))
217 return std::nullopt;
218 if (matchPattern(v, m_ConstantInt(&cst)))
219 return cst.getZExtValue();
220 return std::nullopt;
221}
222
223template <typename OpType>
224static bool staticallyOutOfBounds(OpType op) {
225 if (!op.getBoundsCheck())
226 return false;
227 MemRefType bufferType = op.getMemref().getType();
228 if (!bufferType.hasStaticShape())
229 return false;
230 int64_t offset;
231 SmallVector<int64_t> strides;
232 if (failed(bufferType.getStridesAndOffset(strides, offset)))
233 return false;
234 int64_t result = offset + op.getIndexOffset().value_or(0);
235 if (op.getSgprOffset()) {
236 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
237 if (!sgprOffset)
238 return false;
239 result += *sgprOffset;
240 }
241 if (strides.size() != op.getIndices().size())
242 return false;
243 int64_t indexVal = 0;
244 for (auto pair : llvm::zip(strides, op.getIndices())) {
245 int64_t stride = std::get<0>(pair);
246 Value idx = std::get<1>(pair);
247 std::optional<uint32_t> idxVal = getConstantUint32(idx);
248 if (!idxVal)
249 return false;
250 indexVal += stride * *idxVal;
251 }
252 result += indexVal;
253 if (result > std::numeric_limits<uint32_t>::max())
254 // Overflow means we can't prove an out-of-bounds access, so don't drop.
255 return false;
256 return result >= bufferType.getNumElements();
257}
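// Worked example (illustrative): for a statically shaped memref<4x8xi32>
// buffer (identity layout, 32 elements), constant indices [3, 7], an
// indexOffset of 2, and no sgprOffset, the linearized index is
// 0 + 2 + (3 * 8 + 7 * 1) = 33 >= 32, so the access is statically out of
// bounds whenever the bounds check is enabled.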
258
259namespace {
260template <typename OpType>
261struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
262 using OpRewritePattern<OpType>::OpRewritePattern;
263
264 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
265 if (!staticallyOutOfBounds(op))
266 return failure();
267 Type loadType = op.getResult().getType();
268 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
269 rw.getZeroAttr(loadType));
270 return success();
271 }
272};
273
274template <typename OpType>
275struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
276 using OpRewritePattern<OpType>::OpRewritePattern;
277
278 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
279 if (!staticallyOutOfBounds(op))
280 return failure();
281
282 rw.eraseOp(op);
283 return success();
284 }
285};
286} // end namespace
287
288void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
289 MLIRContext *context) {
290 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
291}
292
293void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
294 MLIRContext *context) {
295 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
296}
297
298void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
299 RewritePatternSet &results, MLIRContext *context) {
300 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
301}
302
303void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
304 RewritePatternSet &results, MLIRContext *context) {
305 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
306}
307
308void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
309 RewritePatternSet &results, MLIRContext *context) {
310 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
311}
312
313void RawBufferAtomicUminOp::getCanonicalizationPatterns(
314 RewritePatternSet &results, MLIRContext *context) {
315 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
316}
317
318void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
319 RewritePatternSet &results, MLIRContext *context) {
320 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
321 context);
322}
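// Note that the cmpswap op is registered with the "load" pattern above because
// it produces a result: a statically out-of-bounds cmpswap folds to a zero
// constant of its result type instead of simply being erased.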
323
324//===----------------------------------------------------------------------===//
325// ScaledExtPackedMatrixOp
326//===----------------------------------------------------------------------===//
327LogicalResult ScaledExtPackedMatrixOp::verify() {
328 int blockSize = getBlockSize();
329 assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
330
331 int firstScaleByte = getFirstScaleByte();
332 int firstScaleLane = getFirstScaleLane();
333 auto sourceType = cast<VectorType>(getSource().getType());
334 Type elementType = sourceType.getElementType();
335 auto floatType = cast<FloatType>(elementType);
336 unsigned bitWidth = floatType.getWidth();
337
338 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
339
340 const bool is_fp8 = bitWidth == 8;
341 const bool is_block_16 = blockSize == 16;
342
343 if (!is_fp8) {
344 if (is_block_16) {
345 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
346 return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
347 "or 1 for f4 and f6.");
348 }
349 } else {
350 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
351 return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
352 "or 2 for f4 and f6.");
353 }
354 }
355 } else {
356 if (is_block_16) {
357 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
358 ((firstScaleLane == 16) && (firstScaleByte == 2));
359 if (!is_valid) {
360 return emitOpError("blockSize of 16 can only have (firstScaleLane, "
361 "firstScaleByte) be (0, 0) or (16, 2) for f8.");
362 }
363 }
364 }
365
366 return success();
367}
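// Summary of the constraints verified above:
//   f4/f6 sources: blockSize 16 allows firstScaleByte 0 or 1; blockSize 32
//   allows firstScaleByte 0 or 2.
//   f8 sources: blockSize 16 requires (firstScaleLane, firstScaleByte) to be
//   (0, 0) or (16, 2); blockSize 32 is not further constrained here.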
368
369//===----------------------------------------------------------------------===//
370// WMMAOp
371//===----------------------------------------------------------------------===//
372
373ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
374 IntegerAttr &m, IntegerAttr &n,
375 IntegerAttr &k) {
376 SmallVector<int64_t, 3> dimensions;
377 if (parser.parseDimensionList(dimensions, false, false))
378 return failure();
379 if (dimensions.size() != 3)
380 return parser.emitError(parser.getCurrentLocation())
381 << "expected 3 dimensions in MNK dimension list";
382
383 m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
384 n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
385 k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
386 return success();
387}
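// As an illustration, the custom<MNKDimensionList> syntax accepts a static
// dimension list such as `16x16x32` and splits it into the i32 attributes
// m = 16, n = 16, k = 32; a two-entry list like `16x16` is rejected with the
// "expected 3 dimensions" error above.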
388
389LogicalResult WMMAOp::verify() {
390 auto sourceAType = cast<VectorType>(getSourceA().getType());
391 auto sourceBType = cast<VectorType>(getSourceB().getType());
392 auto destType = cast<VectorType>(getDestC().getType());
393
394 Type sourceAElemType = sourceAType.getElementType();
395 Type sourceBElemType = sourceBType.getElementType();
396 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
397 return emitOpError("source vectors have different lengths: ")
398 << sourceAType << " vs. " << sourceBType;
399 }
400
401 bool isDestFloat = destType.getElementType().isFloat();
402 bool isSrcFloat = sourceAElemType.isFloat();
403
404 if (isDestFloat && !isSrcFloat)
405 return emitOpError("expected float sources with float destination");
406 if (!isDestFloat && isSrcFloat)
407 return emitOpError("expected int sources with int destination");
408
409 if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
410 return emitOpError(
411 "source element types must match (except for fp8/bf8) but have ")
412 << sourceAType << " and " << sourceBType;
413 }
414
415 if (isSrcFloat) {
416 if (getClamp())
417 return emitOpError("clamp flag is not supported for float types");
418 if (getUnsignedA() || getUnsignedB())
419 return emitOpError("unsigned flags are not supported for float types");
420 }
421 return success();
422}
423
424//===----------------------------------------------------------------------===//
425// ScaledWMMAOp
426//===----------------------------------------------------------------------===//
427
428LogicalResult ScaledWMMAOp::verify() {
429 // Helper functions for type classification.
430 auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
431 auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
432 auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
433 auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
434 auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
435 auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
436
437 auto sourceAType = cast<VectorType>(getSourceA().getType());
438 auto sourceBType = cast<VectorType>(getSourceB().getType());
439 auto destType = cast<VectorType>(getDestC().getType());
440
441 // Validate source element types are small floats (fp4/fp6/fp8).
442 Type aElemType = sourceAType.getElementType();
443 Type bElemType = sourceBType.getElementType();
444
445 // Validate vector lengths based on dimensions.
446 int64_t m = getM();
447 int64_t aLen = sourceAType.getNumElements();
448 int64_t bLen = sourceBType.getNumElements();
449 int64_t expectedOutLen = (m == 16) ? 8 : 16;
450
451 if (destType.getNumElements() != expectedOutLen)
452 return emitOpError("expected output vector of length ")
453 << expectedOutLen << " but got " << destType.getNumElements();
454
455 if (m == 16) {
456 // For 16×16×128: both A and B must be 64 elements.
457 if (aLen != 64)
458 return emitOpError(
459 "for 16x16x128, sourceA must have 64 elements but got ")
460 << aLen;
461 if (bLen != 64)
462 return emitOpError(
463 "for 16x16x128, sourceB must have 64 elements but got ")
464 << bLen;
465 } else { // m == 32
466 // For 32×16×128: only fp4 is supported, A is 128, B is 64.
467 if (!isF4(aElemType) && !isF4(bElemType))
468 return emitOpError("32x16x128 only supports fp4 element types");
469
470 if (aLen != 128)
471 return emitOpError(
472 "for 32x16x128, sourceA must have 128 elements but got ")
473 << aLen;
474 if (bLen != 64)
475 return emitOpError(
476 "for 32x16x128, sourceB must have 64 elements but got ")
477 << bLen;
478
479 // For 32x16x128, matrix A uses all 32 lanes so a_first_scale_lane must be
480 // 0.
481 if (getAFirstScaleLane() != 0)
482 return emitOpError("for 32x16x128, a_first_scale_lane must be 0");
483 }
484
485 // Validate scale types and their compatibility with matrix element types.
486 auto scaleAType = cast<VectorType>(getScaleA().getType());
487 auto scaleBType = cast<VectorType>(getScaleB().getType());
488 Type scaleAElemType = scaleAType.getElementType();
489 Type scaleBElemType = scaleBType.getElementType();
490
491 // Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN).
492 if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
493 return emitOpError(
494 "scale operands must have f8 element types (E8M0FNU or E4M3FN)");
495
496 // Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid.
497 if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
498 return success();
499
500 // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E4M3).
501 if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
502 isF4(bElemType) && isE4M3(scaleBElemType))
503 return success();
504
505 // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E4M3), Scale B (E8M0).
506 if (isF4(aElemType) && isE4M3(scaleAElemType) &&
507 (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
508 return success();
509
510 // Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3).
511 if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
512 isE4M3(scaleBElemType))
513 return success();
514
515 // No valid combination matched.
516 return emitOpError("invalid combination of matrix and scale types: ")
517 << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
518 << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
519}
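// Summary of the scale pairings accepted above: E8M0 scales on both operands
// are valid for any fp4/fp6/fp8 matrices, while an E4M3 scale is only accepted
// on an fp4 operand, either on both sides or opposite an E8M0-scaled fp8/fp6
// operand.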
520
521//===----------------------------------------------------------------------===//
522// MFMAOp
523//===----------------------------------------------------------------------===//
524LogicalResult MFMAOp::verify() {
525 constexpr uint32_t waveSize = 64;
526 Builder b(getContext());
527
528 Type sourceType = getSourceA().getType();
529 Type destType = getDestC().getType();
530
531 Type sourceElem = sourceType, destElem = destType;
532 uint32_t sourceLen = 1, destLen = 1;
533 if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
534 sourceLen = sourceVector.getNumElements();
535 sourceElem = sourceVector.getElementType();
536 }
537 if (auto destVector = dyn_cast<VectorType>(destType)) {
538 destLen = destVector.getNumElements();
539 destElem = destVector.getElementType();
540 }
541
542 Type sourceBType = getSourceB().getType();
543 if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
544 int64_t sourceBLen = 1;
545 Type sourceBElem = sourceBType;
546 if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
547 sourceBLen = sourceBVector.getNumElements();
548 sourceBElem = sourceBVector.getElementType();
549 }
550 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
551 !sourceBElem.isFloat(4))
552 return emitOpError("expected both source operands to have small-float "
553 "elements if one does");
554 if (sourceLen != sourceBLen)
555 return emitOpError(
556 "expected both small-float source vectors to have the same length");
557 } else {
558 if (sourceType != sourceBType)
559 return emitOpError("expected both non-small-float source operand types "
560 "to match exactly");
561 }
562 // Normalize the wider integer types the compiler expects to i8.
563 if (sourceElem.isInteger(32)) {
564 sourceLen *= 4;
565 sourceElem = b.getI8Type();
566 }
567 if (sourceElem.isInteger(64)) {
568 sourceLen *= 8;
569 sourceElem = b.getI8Type();
570 }
571
572 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
573 if (sourceLen != numSourceElems)
574 return emitOpError("expected " + Twine(numSourceElems) +
575 " source values for this operation but got " +
576 Twine(sourceLen));
577
578 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
579 if (destLen != numDestElems)
580 return emitOpError("expected " + Twine(numDestElems) +
581 " result values for this operation but got " +
582 Twine(destLen));
583
584 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
585 return emitOpError(
586 "double-precision ops do not support permuting lanes of B");
587 if (destElem.isF64() && getCbsz() != 0)
588 return emitOpError(
589 "double-precision ops do not support permuting lanes of A");
590 if (getAbid() >= (1u << getCbsz()))
591 return emitOpError(
592 "block ID for permuting A (abid) must be below 2 ** cbsz");
593
594 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
595 return emitOpError(
596 "negation flags only available for double-precision operations");
597
598 return success();
599}
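// Worked example (illustrative): with m = 32, n = 32, k = 8, blocks = 1, and a
// 64-lane wave, the checks above require (32 * 8 * 1) / 64 = 4 elements per
// source operand and (32 * 32 * 1) / 64 = 16 elements in the destination.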
600
601//===----------------------------------------------------------------------===//
602// SparseMFMAOp
603//===----------------------------------------------------------------------===//
604
605LogicalResult SparseMFMAOp::verify() {
606 constexpr uint32_t waveSize = 64;
607
608 auto sparseType = cast<VectorType>(getSourceA().getType());
609 auto denseType = cast<VectorType>(getSourceB().getType());
610 auto destType = cast<VectorType>(getDestC().getType());
611
612 Type sparseElem = sparseType.getElementType();
613 Type denseElem = denseType.getElementType();
614 int64_t sparseLen = sparseType.getNumElements();
615 int64_t denseLen = denseType.getNumElements();
616 int64_t destLen = destType.getNumElements();
617
618 if (denseLen != 2 * sparseLen)
619 return emitOpError("expected dense source operand to have exactly double "
620 "the number of elements of the sparse source operand");
621
622 // Check that source element types are compatible.
623 // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
624 // For other types, element types must match exactly.
625 bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
626 if (!bothFloat8 && sparseElem != denseElem)
627 return emitOpError(
628 "expected source operands to have the same element type");
629
630 // When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
631 // When CBSZ != 0, the first index set is always used (ABID ignored).
632 bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
633 // 8-bit source: ABID selects one of two 16-bit index sets.
634 if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
635 return emitOpError("ABID must be 0 or 1 for 8-bit source data");
636 // 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid).
637 if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
638 return emitOpError("ABID must be between 0 and 3 for 16-bit source data");
639
640 // Validate sparseIdx type matches source element type.
641 auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
642 if (is8BitSource) {
643 // 8-bit source data requires vector<2xi16> sparse indices.
644 if (sparseIdxType.getNumElements() != 2 ||
645 !sparseIdxType.getElementType().isInteger(16))
646 return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
647 "source data, but got ")
648 << getSparseIdx().getType();
649 } else {
650 // 16-bit source data requires vector<4xi8> sparse indices.
651 if (sparseIdxType.getNumElements() != 4 ||
652 !sparseIdxType.getElementType().isInteger(8))
653 return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
654 "source data, but got ")
655 << getSparseIdx().getType();
656 }
657
658 int64_t expectedSourceElems = (getM() * getK()) / waveSize;
659 if (denseLen != expectedSourceElems)
660 return emitOpError("expected " + Twine(expectedSourceElems) +
661 " source values for this operation but got " +
662 Twine(denseLen));
663
664 int64_t expectedDestElems = (getM() * getN()) / waveSize;
665 if (destLen != expectedDestElems)
666 return emitOpError("expected " + Twine(expectedDestElems) +
667 " result values for this operation but got " +
668 Twine(destLen));
669
670 return success();
671}
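// Worked example (illustrative): with m = 16, n = 16, k = 64 on a 64-lane
// wave, the dense source must carry (16 * 64) / 64 = 16 elements, the sparse
// source half of that (8 elements), and the destination
// (16 * 16) / 64 = 4 elements.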
672
673//===----------------------------------------------------------------------===//
674// DPPOp
675//===----------------------------------------------------------------------===//
676LogicalResult DPPOp::verify() {
677 Type srcType = getSrc().getType();
678 if (srcType.getIntOrFloatBitWidth() > 64) {
679 return emitOpError("integer and floating point types larger than 64 bits "
680 "are not supported");
681 }
682
683 DPPPerm kind = getKind();
684 Attribute permArgument = getPermArgument().value_or(Attribute{});
685
686 switch (kind) {
687
688 case DPPPerm::quad_perm: {
689 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
690 if (!quadPermAttr || quadPermAttr.size() != 4) {
691 return emitOpError("quad_perm attribute must have exactly 4 elements");
692 }
693 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
694 int32_t num = elem.getInt();
695 if (num < 0 || num > 3) {
696 return emitOpError(
697 "Each element of quad_perm must be in the range [0, 3]");
698 }
699 }
700 } break;
701
702 case DPPPerm::row_shl:
703 case DPPPerm::row_shr:
704 case DPPPerm::row_ror: {
705 if (!permArgument) {
706 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
707 "' value not specified");
708 }
709 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
710 uint32_t attrValue = intAttr.getInt();
711 if (attrValue < 1 || attrValue > 15) {
712 return emitOpError("Attribute value must be between 1 and 15");
713 }
714 }
715 } break;
716
717 case DPPPerm::wave_shl:
718 case DPPPerm::wave_shr:
719 case DPPPerm::wave_rol:
720 case DPPPerm::wave_ror:
721 case DPPPerm::row_mirror:
722 case DPPPerm::row_half_mirror:
723 case DPPPerm::row_bcast_15:
724 case DPPPerm::row_bcast_31: {
725 if (permArgument && !isa<UnitAttr>(permArgument)) {
726 return emitOpError("Expected unit attribute for permArgument, but found "
727 "non-trivial argument");
728 }
729 break;
730 }
731 }
732 return success();
733}
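// In summary: quad_perm takes a 4-element array with values in [0, 3];
// row_shl, row_shr, and row_ror take an integer in [1, 15]; the remaining
// permutations take no argument beyond an optional unit attribute.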
734
735//===----------------------------------------------------------------------===//
736// PermlaneSwapOp
737//===----------------------------------------------------------------------===//
738LogicalResult PermlaneSwapOp::verify() {
739 unsigned rowLength = getRowLength();
740
741 if (rowLength != 16 && rowLength != 32)
742 return emitOpError("row_length attribute must either be 16 or 32.");
743
744 return success();
745}
746
747/// Remove an amdgpu.lds_barrier that is immediately followed by another one.
748static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op,
749 PatternRewriter &rewriter) {
750 if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
751 rewriter.eraseOp(op);
752 return success();
753 }
754 return failure();
755}
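// For instance, in a block containing two back-to-back amdgpu.lds_barrier ops,
// this rewrite erases the first barrier and keeps the second one.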
756
757void LDSBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
758 MLIRContext *context) {
759 results.add(eraseRedundantLDSBarrierOps);
760}
761
762//===----------------------------------------------------------------------===//
763// MemoryCounterWaitOp
764//===----------------------------------------------------------------------===//
765
766namespace {
767/// Fuse adjacent memory counter wait ops, taking the minimum value of the
768/// counters.
769struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
770 using Base::Base;
771
772 LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
773 PatternRewriter &rewriter) const override {
774 auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
775 if (!next)
776 return failure();
777
778 auto setters = {&MemoryCounterWaitOp::setLoad,
779 &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
780 &MemoryCounterWaitOp::setExp,
781 &MemoryCounterWaitOp::setTensor};
782 auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
783 op.getTensor()};
784 auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
785 next.getExp(), next.getTensor()};
786 rewriter.modifyOpInPlace(op, [&] {
787 for (auto [setter, lhs, rhs] :
788 llvm::zip_equal(setters, lhsVals, rhsVals)) {
789 if (lhs && rhs) {
790 (op.*setter)(std::min(*lhs, *rhs));
791 } else if (lhs) {
792 (op.*setter)(*lhs);
793 } else if (rhs) {
794 (op.*setter)(*rhs);
795 }
796 }
797 });
798 rewriter.eraseOp(next);
799 return success();
800 }
801};
802} // namespace
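// For instance (illustrative counter values), fusing a wait with load = 1 and
// ds = 0 into a following wait with load = 2 and exp = 3 yields a single op
// with load = 1, ds = 0, and exp = 3: counters present on both ops take the
// minimum, and counters present on only one op are carried over unchanged.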
803
804void MemoryCounterWaitOp::getCanonicalizationPatterns(
805 RewritePatternSet &results, MLIRContext *context) {
806 results.add<FuseMemoryCounterWaitOp>(context);
807}
808
809//===----------------------------------------------------------------------===//
810// GatherToLDSOp
811//===----------------------------------------------------------------------===//
812
813LogicalResult GatherToLDSOp::verify() {
814 MemRefType srcType = cast<MemRefType>(getSrc().getType());
815 MemRefType dstType = cast<MemRefType>(getDst().getType());
816
817 if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
818 return emitOpError("destination type innermost dim must be contiguous");
819
820 auto elemType = srcType.getElementType();
821 // Check $src and $dst element types are the same.
822 if (elemType != dstType.getElementType())
823 return emitOpError("source and destination element types must match");
824
825 // copy type sizes should be 1, 2, 4, 12 or 16 bytes.
826 auto transferType = getTransferType();
827 int transferSize;
828 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
829 transferSize = vectorTransfer.getNumElements() *
830 vectorTransfer.getElementTypeBitWidth();
831 } else {
832 transferSize = transferType.getIntOrFloatBitWidth();
833 }
834 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
835 return emitOpError(
836 "Transferring type size must be 8, 16, 32, 96 or 128 bits");
837
838 if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
839 !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
840 return emitOpError(
841 "source memory address space must be global or fat raw buffer");
842
843 if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
844 return emitOpError("destination memory address space must be Workgroup");
845
846 return success();
847}
848
849namespace {
850/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
851/// information or changes layout, the cast can be skipped.
852struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
853 using OpRewritePattern::OpRewritePattern;
854
855 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
856 PatternRewriter &rewriter) const override {
857 bool modified = false;
858 auto foldCast = [&](OpOperand &operand) {
859 if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
860 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
861 rewriter.modifyOpInPlace(gatherOp,
862 [&] { operand.assign(castOp.getSource()); });
863 modified = true;
864 }
865 }
866 };
867
868 foldCast(gatherOp.getSrcMutable());
869 foldCast(gatherOp.getDstMutable());
870
871 return success(modified);
872 }
873};
874} // namespace
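// For example (illustrative), if the global operand of a gather_to_lds op is
// defined by `memref.cast %src : memref<64xf32> to memref<?xf32>`, the cast
// only erases static size information, so the op is updated in place to read
// from %src directly and the cast becomes dead.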
875
876void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
877 MLIRContext *context) {
878 results.add<FoldGatherToLDSOfCast>(context);
879}
880
881//===----------------------------------------------------------------------===//
882// TransposeLoadOp
883//===----------------------------------------------------------------------===//
884
885LogicalResult TransposeLoadOp::verify() {
886 MemRefType srcType = cast<MemRefType>(getSrc().getType());
887
888 if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
889 return emitOpError("source memory address space must be Workgroup");
890
891 auto transferType = cast<VectorType>(getType());
892 size_t numElements = transferType.getNumElements();
893 size_t elementTypeSize =
894 transferType.getElementType().getIntOrFloatBitWidth();
895
896 // ElementSize -> NumElements
897 const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
898 {4, 16},
899 {6, 16},
900 {8, 8},
901 {16, 4},
902 };
903
904 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
905 if (validNumElems == kValidLoadSizeMap.end())
906 return emitOpError("Unsupported element type size for transpose load: ")
907 << elementTypeSize << " bits";
908
909 if (numElements != validNumElems->second)
910 return emitOpError(
911 "Transferring type size mismatch: expected num of elements: ")
912 << validNumElems->second;
913
914 return success();
915}
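// The table above encodes the supported per-lane transfer shapes: 16 elements
// for 4-bit and 6-bit element types, 8 elements for 8-bit types, and 4
// elements for 16-bit types (64 bits per lane, or 96 bits in the 6-bit case).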
916
917//===----------------------------------------------------------------------===//
918// MakeDmaBaseOp
919//===----------------------------------------------------------------------===//
920
921template <typename BaseOp>
922static LogicalResult verifyBase(BaseOp op) {
923 auto ldsType = cast<MemRefType>(op.getLds().getType());
924 auto globalType = cast<MemRefType>(op.getGlobal().getType());
925 if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
926 return op.emitOpError(
927 "lds memref must have workgroup address space attribute.");
928 if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
929 return op.emitOpError(
930 "global memref must have global address space attribute.");
931
932 Type elementType = ldsType.getElementType();
933 unsigned width = elementType.getIntOrFloatBitWidth();
934
935 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
936 return op.emitOpError(
937 "element type must be 1, 2, 4, or 8 bytes long but type was ")
938 << width << " bits long.";
939 return success();
940}
941
942LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
943
944//===----------------------------------------------------------------------===//
945// MakeGatherDmaBaseOp
946//===----------------------------------------------------------------------===//
947
948LogicalResult
949TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
950 Type elementType, Type indexType) {
951 unsigned width = elementType.getIntOrFloatBitWidth();
952 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
953 return emitError()
954 << "element type must be 1, 2, 4, or 8 bytes wide but type "
955 << elementType << " is " << width / 8 << " bytes wide.";
956 MLIRContext *ctx = elementType.getContext();
957 Type i16 = IntegerType::get(ctx, 16);
958 Type i32 = IntegerType::get(ctx, 32);
959 if (!llvm::is_contained({i16, i32}, indexType))
960 return emitError() << "index type must be i16 or i32 but index type is "
961 << indexType << ".";
962 return success();
963}
964
965LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
966
967//===----------------------------------------------------------------------===//
968// MakeDmaDescriptorOp
969//===----------------------------------------------------------------------===//
970
971template <typename DescriptorOp>
972static LogicalResult verifyDescriptorOp(DescriptorOp op) {
973 ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();
974
975 if (globalStaticStrides.empty())
976 return op.emitOpError("strides must not be empty.");
977 if (globalStaticStrides.back() != 1)
978 return op.emitOpError("strides for the innermost dimension must be 1.");
979
980 ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
981 size_t rank = globalStaticSizes.size();
982 if (rank > 5)
983 return op.emitOpError("tensor and tile must be at most of rank 5.");
984 if (rank != globalStaticStrides.size())
985 return op.emitOpError("strides and sizes must have same rank.");
986
987 ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
988 if (rank != sharedStaticSizes.size())
989 return op.emitOpError("tensor must have same rank as tile.");
990
991 unsigned elementTypeWidth = op.getElementTypeWidth();
992 if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
993 return op.emitOpError(
994 "element type width must be 1, 2, 4 or 8 bytes, but was ")
995 << elementTypeWidth << " bits long";
996
997 if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
998 auto atomicBarrierAddressType =
999 cast<MemRefType>(atomicBarrierAddress.getType());
1000 bool barrierInLDS =
1001 hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
1002 if (!barrierInLDS)
1003 return op.emitOpError("atomic barrier address must be in LDS.");
1004 }
1005
1006 if (op.getEarlyTimeout() && !op.getWorkgroupMask())
1007 return op.emitOpError(
1008 "early timeout does not apply when workgroup_mask is not set.");
1009 return success();
1010}
1011
1012template <typename DescriptorOp, typename FoldAdaptor>
1013static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
1014 SmallVector<OpFoldResult> mixedGlobalSizes(op.getMixedGlobalSizes());
1015 SmallVector<OpFoldResult> mixedGlobalStrides(op.getMixedGlobalStrides());
1016 SmallVector<OpFoldResult> mixedSharedSizes(op.getMixedSharedSizes());
1017
1018 if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
1019 /*onlyNonZero=*/true)) &&
1020 failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
1021 /*onlyNonZero=*/true)) &&
1022 failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
1023 /*onlyNonZero=*/true)))
1024 return nullptr;
1025
1026 SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
1027 dynamicSharedSizes;
1028 SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
1029 staticSharedSizes;
1030
1031 dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
1032 staticGlobalSizes);
1033 op.setGlobalStaticSizes(staticGlobalSizes);
1034 op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
1035
1036 dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
1037 staticGlobalStrides);
1038 op.setGlobalStaticStrides(staticGlobalStrides);
1039 op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
1040
1041 dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
1042 staticSharedSizes);
1043 op.setSharedStaticSizes(staticSharedSizes);
1044 op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
1045 return op.getResult();
1046}
1047
1048LogicalResult MakeDmaDescriptorOp::verify() {
1049 return verifyDescriptorOp(*this);
1050}
1051
1052OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1053 return foldDescriptorOp(*this, adaptor);
1054}
1055
1056//===----------------------------------------------------------------------===//
1057// MakeGatherDmaDescriptorOp
1058//===----------------------------------------------------------------------===//
1059
1060LogicalResult MakeGatherDmaDescriptorOp::verify() {
1061 ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
1062 size_t rank = globalStaticSizes.size();
1063 if (rank > 2)
1064 return emitOpError(
1065 "tensor and tile must be at most of rank two in gather mode.");
1066 Value indices = getIndices();
1067 Type elementType = cast<VectorType>(indices.getType()).getElementType();
1068 if (elementType != getBase().getType().getIndexType())
1069 return emitOpError("indices' element type must match base's index type.");
1070
1071 return verifyDescriptorOp(*this);
1072}
1073
1074OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1075 return foldDescriptorOp(*this, adaptor);
1076}
1077
1078//===----------------------------------------------------------------------===//
1079// ScaledMFMAOp
1080//===----------------------------------------------------------------------===//
1081
1082namespace {
1083/// If a scale operand of a scaled MFMA is produced by inserting a single
1084/// extracted scale, rewrite it to extract a packed group of four scales.
1085struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
1086 using OpRewritePattern::OpRewritePattern;
1087
1088 LogicalResult matchAndRewrite(ScaledMFMAOp op,
1089 PatternRewriter &rewriter) const override {
1090 Location loc = op.getLoc();
1091 auto setOpsel = [&op](unsigned idx, int64_t val) {
1092 switch (idx) {
1093 case 3:
1094 op.setScalesIdxA(val);
1095 break;
1096 case 4:
1097 op.setScalesIdxB(val);
1098 break;
1099 default:
1100 break;
1101 }
1102 };
1103
1104 // For every scale operand of this ScaledMFMAOp, if the scale is produced by
1105 // the extraction of a single scale from some vector, then attempt to
1106 // extract 4 values from that vector instead.
1107 //
1108 // Example: (f8 here means f8E8M0FNU)
1109 // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
1110 // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
1111 // amdgpu.scaled_mfma(%scale[0] * ...
1112 //
1113 // rewrite to:
1114 //
1115 // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
1116 // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
1117 // amdgpu.scaled_mfma(%scale[0-3] * ...
1118 //
1119 // This creates duplicate shape_casts for every use but these will be
1120 // removed in CSE.
1121 for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
1122 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
1123 if (!insertOp) {
1124 return rewriter.notifyMatchFailure(op,
1125 "defining op not a vector.insert");
1126 }
1127 // If the inserted value is not a single scalar, then it has already been packed.
1128 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
1129 return rewriter.notifyMatchFailure(
1130 op, "scaled mfma operand already packed");
1131 }
1132
1133 auto extractOp =
1134 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
1135 if (!extractOp) {
1136 return rewriter.notifyMatchFailure(op,
1137 "defining op not a vector.extract");
1138 }
1139
1140 Value scaleSrc = extractOp.getOperand(0);
1141 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
1142 if (!scaleSrcType) {
1143 return rewriter.notifyMatchFailure(op, "not a vector type");
1144 }
1145
1146 // We do not handle dynamic dims yet, assume that the input is padded to
1147 // a static shape now.
1148 if (!scaleSrcType.hasStaticShape()) {
1149 return rewriter.notifyMatchFailure(op,
1150 "dynamic dims not yet supported");
1151 }
1152
1153 int64_t numElements = scaleSrcType.getNumElements();
1154 if (numElements <= 4) {
1155 return rewriter.notifyMatchFailure(
1156 op, "no packing if # of scales less than four");
1157 }
1158
1159 // Find a linearized idx using the size and offsets of the extract op.
1160 auto extractedPos = llvm::to_vector_of<int64_t>(
1161 llvm::reverse(extractOp.getStaticPosition()));
1162 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
1163 int64_t scaleSrcRank = scaleSrcType.getRank();
1164 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
1165 for (int64_t i = 1; i < scaleSrcRank; ++i) {
1166 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
1167 }
1168 int64_t idx = linearize(extractedPos, extractSizes);
1169
1170 // All n scales (where n is the total number of scales) must now be
1171 // extracted in chunks of 4 elements. This is done by dividing the
1172 // original vector of scales into groups of 4 elements
1173 // at offsets 0, 4, ..., n - 4. All extractions of a
1174 // scale at a particular index are now replaced with an extraction
1175 // of the entire group of 4 elements to which that index belongs.
1176 //
1177 // If the number of scales is not divisible by 4, the final group of 4
1178 // elements instead starts at offset n - 4, and the opsel index is
1179 // adjusted so it still addresses the originally extracted scale.
1180 int64_t offset = idx - (idx % 4);
1181 int64_t opsel = idx - offset;
1182 int64_t size = 4l;
1183 // Accommodate remaining elements in the case of non-4-divisible vectors.
1184 if (numElements - offset < size) {
1185 opsel = size - (numElements - idx);
1186 offset = numElements - 4l;
1187 }
1188 Type scaleSrcElemType = scaleSrcType.getElementType();
1189 auto newSrcType =
1190 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
1191 Value newScaleSrc =
1192 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
1193 auto extract = vector::ExtractStridedSliceOp::create(
1194 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
1195 ArrayRef{int64_t(1)});
1196 rewriter.modifyOpInPlace(op, [&] {
1197 op->setOperand(opIdx, extract);
1198 setOpsel(opIdx, opsel);
1199 });
1200 }
1201 return success();
1202 }
1203};
1204} // namespace
1205
1206void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
1207 MLIRContext *context) {
1208 results.add<PackScales>(context);
1209}
1210
1211#define GET_OP_CLASSES
1212#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"