MLIR 23.0.0git
AMDGPUDialect.cpp
1//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14
22#include "mlir/IR/Builders.h"
24#include "mlir/IR/Diagnostics.h"
26#include "mlir/IR/Matchers.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34
35#include <algorithm>
36#include <cstdint>
37#include <limits>
38#include <optional>
39
40using namespace mlir;
41using namespace mlir::amdgpu;
42
43#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
44
45namespace {
46struct AMDGPUInlinerInterface final : DialectInlinerInterface {
47 using DialectInlinerInterface::DialectInlinerInterface;
48 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
49 return true;
50 }
51};
52} // namespace
53
54void AMDGPUDialect::initialize() {
55 addOperations<
56#define GET_OP_LIST
57#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
58 >();
59 addTypes<
60#define GET_TYPEDEF_LIST
61#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
62 >();
63 addAttributes<
64#define GET_ATTRDEF_LIST
65#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
66 >();
67 addInterfaces<AMDGPUInlinerInterface>();
68}
69
70//===----------------------------------------------------------------------===//
71// 8-bit float ops
72//===----------------------------------------------------------------------===//
73LogicalResult PackedTrunc2xFp8Op::verify() {
74 if (getExisting() && getExisting().getType() != getResult().getType())
75 return emitOpError("existing values must have same type as result");
76 return success();
77}
78
79LogicalResult PackedStochRoundFp8Op::verify() {
80 if (getExisting() && getExisting().getType() != getResult().getType())
81 return emitOpError("existing values must have same type as result");
82 return success();
83}
84
85//===----------------------------------------------------------------------===//
86// mxfp float ops
87//===----------------------------------------------------------------------===//
88LogicalResult PackedScaledTruncOp::verify() {
89 if (getExisting() && getExisting().getType() != getResult().getType())
90 return emitOpError("existing values must have same type as result");
91 return success();
92}
93
94//===----------------------------------------------------------------------===//
95// FatRawBufferCastOp
96//===----------------------------------------------------------------------===//
97
98/// Convert the type `source` to one with the same sizes and strides - and
99/// offset, unless `resetOffset` is true, in which case the offset is reset to
100/// 0. If the offset should be reset but the layout of `source` is neither the
101/// identity layout nor a strided layout, this function fails.
102static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
103 bool resetOffset) {
104 MLIRContext *ctx = source.getContext();
105 MemRefType::Builder mb(source);
106  mb.setMemorySpace(
107      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
108 MemRefLayoutAttrInterface layout = source.getLayout();
109 if (resetOffset && !layout.isIdentity()) {
110 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
111 if (!stridedLayout)
112 return failure();
113 MemRefLayoutAttrInterface newLayout =
114 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
115 // Special case: if resetting the offset causes the strided layout to become
116 // the identity layout, then reset to the identity layout.
117 // TODO: this'll get a lot simpler when we have the contiguous layout.
118 SmallVector<int64_t> stridesIfIdentity;
119 if (source.hasStaticShape()) {
120 stridesIfIdentity = computeSuffixProduct(source.getShape());
121 } else if (source.getRank() <= 1) {
122 stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
123 }
124 if (stridesIfIdentity == stridedLayout.getStrides()) {
125 newLayout = AffineMapAttr::get(
126 AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
127 }
128 mb.setLayout(newLayout);
129 }
130 return (MemRefType)(mb);
131}
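For illustration, a minimal sketch of what getFatRawBufferTypeLike produces; the concrete types below are hypothetical examples, not taken from a test. With resetOffset = true, a strided source such as

    memref<8x8xf32, strided<[16, 1], offset: ?>>

is rewritten to

    memref<8x8xf32, strided<[16, 1]>, #amdgpu.address_space<fat_raw_buffer>>

that is, the offset is reset to 0 and the memory space becomes the fat raw buffer space. When the strides of a statically shaped source equal the suffix product of its shape (for example strided<[8, 1]> on an 8x8 memref), the special case above folds the layout back to the identity, and a source whose layout is neither identity nor strided makes the function return failure().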
132
133LogicalResult FatRawBufferCastOp::inferReturnTypes(
134 MLIRContext *context, std::optional<Location> location, ValueRange operands,
135 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
136 SmallVectorImpl<Type> &inferredReturnTypes) {
137 Adaptor adaptor(operands, attributes, properties, regions);
138 auto sourceType =
139 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
140 if (!sourceType)
141 return failure();
142 FailureOr<MemRefType> resultType =
143 getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
144 if (failed(resultType))
145 return failure();
146 inferredReturnTypes = SmallVector<Type>{*resultType};
147 return success();
148}
149
150FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
151 int resultIndex,
152 int dim) {
153 assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
154 return memref::getMixedSize(builder, getLoc(), getSource(), dim);
155}
156
157LogicalResult FatRawBufferCastOp::verify() {
158 FailureOr<MemRefType> expectedResultType =
159 getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
160 if (failed(expectedResultType))
161 return emitOpError("source type ")
162 << getSource().getType() << " can't have its offset reset";
163 if (getResult().getType() != *expectedResultType)
164 return emitOpError("expected result type to be ")
165 << *expectedResultType << " but got " << getResult().getType();
166 return success();
167}
168
169static bool hasGlobalMemorySpace(Attribute memorySpace) {
170 if (!memorySpace)
171 return true;
172 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
173 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
174 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
175 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
176 return false;
177}
178
179static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
180 if (!memorySpace)
181 return false;
182 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
183 return intMemorySpace.getInt() == 3;
184 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
185 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
186 return false;
187}
188
189static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
190 if (!memorySpace)
191 return false;
192 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
193 return intMemorySpace.getInt() == 7;
194 if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
195 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
196 return false;
197}
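As a quick reference for these three predicates (the memref types are illustrative): a memref with no memory space, e.g. memref<16xf32>, or with integer space 0 or 1, or with #gpu.address_space<global>, counts as global; integer space 3 or #gpu.address_space<workgroup> counts as workgroup (LDS); and integer space 7 or #amdgpu.address_space<fat_raw_buffer> counts as a fat raw buffer.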
198
199//===----------------------------------------------------------------------===//
200// RawBuffer*Op
201//===----------------------------------------------------------------------===//
202template <typename T>
203static LogicalResult verifyRawBufferOp(T &op) {
204 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
205 bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
206
207 if (!isGlobal)
208 return op.emitOpError(
209 "Buffer ops must operate on a memref in global memory");
210 if (!bufferType.hasRank())
211 return op.emitOpError(
212 "Cannot meaningfully buffer_store to an unranked memref");
213 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
214 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
215 " indices to memref");
216 return success();
217}
218
219LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
220
221LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
222
223LogicalResult RawBufferAtomicFaddOp::verify() {
224 return verifyRawBufferOp(*this);
225}
226
227LogicalResult RawBufferAtomicFmaxOp::verify() {
228 return verifyRawBufferOp(*this);
229}
230
231LogicalResult RawBufferAtomicSmaxOp::verify() {
232 return verifyRawBufferOp(*this);
233}
234
235LogicalResult RawBufferAtomicUminOp::verify() {
236 return verifyRawBufferOp(*this);
237}
238
239LogicalResult RawBufferAtomicCmpswapOp::verify() {
240 return verifyRawBufferOp(*this);
241}
242
243static std::optional<uint32_t> getConstantUint32(Value v) {
244 APInt cst;
245 if (!v.getType().isInteger(32))
246 return std::nullopt;
247 if (matchPattern(v, m_ConstantInt(&cst)))
248 return cst.getZExtValue();
249 return std::nullopt;
250}
251
252template <typename OpType>
253static bool staticallyOutOfBounds(OpType op) {
254 if (!op.getBoundsCheck())
255 return false;
256 MemRefType bufferType = op.getMemref().getType();
257 if (!bufferType.hasStaticShape())
258 return false;
259 int64_t offset;
260 SmallVector<int64_t> strides;
261 if (failed(bufferType.getStridesAndOffset(strides, offset)))
262 return false;
263 int64_t result = offset + op.getIndexOffset().value_or(0);
264 if (op.getSgprOffset()) {
265 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
266 if (!sgprOffset)
267 return false;
268 result += *sgprOffset;
269 }
270 if (strides.size() != op.getIndices().size())
271 return false;
272 int64_t indexVal = 0;
273 for (auto pair : llvm::zip(strides, op.getIndices())) {
274 int64_t stride = std::get<0>(pair);
275 Value idx = std::get<1>(pair);
276 std::optional<uint32_t> idxVal = getConstantUint32(idx);
277 if (!idxVal)
278 return false;
279 indexVal += stride * *idxVal;
280 }
281 result += indexVal;
282 if (result > std::numeric_limits<uint32_t>::max())
283 // Overflow means don't drop
284 return false;
285 return result >= bufferType.getNumElements();
286}
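A worked instance of the computation above (the numbers are illustrative): for a buffer of type memref<4x4xi32, strided<[4, 1], offset: 2>> with bounds checking enabled, constant indices (3, 3), an indexOffset of 5, and no sgprOffset, the linearized position is 2 + 5 + 3*4 + 3*1 = 22. Since 22 >= 16 (the element count) and there is no 32-bit overflow, the access is statically out of bounds, so the canonicalization patterns below may replace a load with a zero constant and erase a store or atomic outright.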
287
288namespace {
289template <typename OpType>
290struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
291 using OpRewritePattern<OpType>::OpRewritePattern;
292
293 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
294 if (!staticallyOutOfBounds(op))
295 return failure();
296 Type loadType = op.getResult().getType();
297 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
298 rw.getZeroAttr(loadType));
299 return success();
300 }
301};
302
303template <typename OpType>
304struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
305 using OpRewritePattern<OpType>::OpRewritePattern;
306
307 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
308 if (!staticallyOutOfBounds(op))
309 return failure();
310
311 rw.eraseOp(op);
312 return success();
313 }
314};
315} // end namespace
316
317void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
318 MLIRContext *context) {
319 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
320}
321
322void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
323 MLIRContext *context) {
324 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
325}
326
327void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
328 RewritePatternSet &results, MLIRContext *context) {
329 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
330}
331
332void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
333 RewritePatternSet &results, MLIRContext *context) {
334 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
335}
336
337void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
338 RewritePatternSet &results, MLIRContext *context) {
339 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
340}
341
342void RawBufferAtomicUminOp::getCanonicalizationPatterns(
343 RewritePatternSet &results, MLIRContext *context) {
344 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
345}
346
347void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
348 RewritePatternSet &results, MLIRContext *context) {
349 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
350 context);
351}
352
353//===----------------------------------------------------------------------===//
354// ScaledExtPackedMatrixOp
355//===----------------------------------------------------------------------===//
356LogicalResult ScaledExtPackedMatrixOp::verify() {
357 int blockSize = getBlockSize();
358 assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
359
360 int firstScaleByte = getFirstScaleByte();
361 int firstScaleLane = getFirstScaleLane();
362 auto sourceType = cast<VectorType>(getSource().getType());
363 Type elementType = sourceType.getElementType();
364 auto floatType = cast<FloatType>(elementType);
365 unsigned bitWidth = floatType.getWidth();
366
367 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
368
369 const bool is_fp8 = bitWidth == 8;
370 const bool is_block_16 = blockSize == 16;
371
372 if (!is_fp8) {
373 if (is_block_16) {
374 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
375 return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
376 "or 1 for f4 and f6.");
377 }
378 } else {
379 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
380 return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
381 "or 2 for f4 and f6.");
382 }
383 }
384 } else {
385 if (is_block_16) {
386 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
387 ((firstScaleLane == 16) && (firstScaleByte == 2));
388 if (!is_valid) {
389 return emitOpError("blockSize of 16 can only have (firstScaleLane, "
390 "firstScaleByte) be (0, 0) or (16, 2) for f8.");
391 }
392 }
393 }
394
395 return success();
396}
397
398//===----------------------------------------------------------------------===//
399// WMMAOp
400//===----------------------------------------------------------------------===//
401
402ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
403                                                IntegerAttr &m, IntegerAttr &n,
404 IntegerAttr &k) {
405 SmallVector<int64_t, 3> dimensions;
406 if (parser.parseDimensionList(dimensions, false, false))
407 return failure();
408 if (dimensions.size() != 3)
409 return parser.emitError(parser.getCurrentLocation())
410 << "expected 3 dimensions in MNK dimension list";
411
412 m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
413 n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
414 k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
415 return success();
416}
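In other words (illustrative input, not a verbatim test case): this custom<MNKDimensionList> parser consumes a static dimension list such as 16x16x16 and yields m = 16, n = 16, k = 16 as i32 attributes. A list of a different arity, e.g. 16x16, is rejected with the "expected 3 dimensions in MNK dimension list" diagnostic, and dynamic dimensions are disallowed because parseDimensionList is called with allowDynamic = false.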
417
418LogicalResult WMMAOp::verify() {
419 auto sourceAType = cast<VectorType>(getSourceA().getType());
420 auto sourceBType = cast<VectorType>(getSourceB().getType());
421 auto destType = cast<VectorType>(getDestC().getType());
422
423 Type sourceAElemType = sourceAType.getElementType();
424 Type sourceBElemType = sourceBType.getElementType();
425 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
426 return emitOpError("source vectors have different lengths: ")
427 << sourceAType << " vs. " << sourceBType;
428 }
429
430 bool isDestFloat = destType.getElementType().isFloat();
431 bool isSrcFloat = sourceAElemType.isFloat();
432
433 if (isDestFloat && !isSrcFloat)
434 return emitOpError("expected float sources with float destination");
435 if (!isDestFloat && isSrcFloat)
436 return emitOpError("expected int sources with int destination");
437
438 if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
439 return emitOpError(
440 "source element types must match (except for fp8/bf8) but have ")
441 << sourceAType << " and " << sourceBType;
442 }
443
444 if (isSrcFloat) {
445 if (getClamp())
446 return emitOpError("clamp flag is not supported for float types");
447 if (getUnsignedA() || getUnsignedB())
448 return emitOpError("unsigned flags are not supported for float types");
449 }
450 return success();
451}
452
453//===----------------------------------------------------------------------===//
454// ScaledWMMAOp
455//===----------------------------------------------------------------------===//
456
457LogicalResult ScaledWMMAOp::verify() {
458 // Helper functions for type classification.
459 auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
460 auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
461 auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
462 auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
463 auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
464 auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
465
466 auto sourceAType = cast<VectorType>(getSourceA().getType());
467 auto sourceBType = cast<VectorType>(getSourceB().getType());
468 auto destType = cast<VectorType>(getDestC().getType());
469
470 // Validate source element types are small floats (fp4/fp6/fp8).
471 Type aElemType = sourceAType.getElementType();
472 Type bElemType = sourceBType.getElementType();
473
474 // Validate vector lengths based on dimensions.
475 int64_t m = getM();
476 int64_t aLen = sourceAType.getNumElements();
477 int64_t bLen = sourceBType.getNumElements();
478 int64_t expectedOutLen = (m == 16) ? 8 : 16;
479
480 if (destType.getNumElements() != expectedOutLen)
481 return emitOpError("expected output vector of length ")
482 << expectedOutLen << " but got " << destType.getNumElements();
483
484 if (m == 16) {
485 // For 16×16×128: both A and B must be 64 elements.
486 if (aLen != 64)
487 return emitOpError(
488 "for 16x16x128, sourceA must have 64 elements but got ")
489 << aLen;
490 if (bLen != 64)
491 return emitOpError(
492 "for 16x16x128, sourceB must have 64 elements but got ")
493 << bLen;
494 } else { // m == 32
495 // For 32×16×128: only fp4 is supported, A is 128, B is 64.
496 if (!isF4(aElemType) && !isF4(bElemType))
497 return emitOpError("32x16x128 only supports fp4 element types");
498
499 if (aLen != 128)
500 return emitOpError(
501 "for 32x16x128, sourceA must have 128 elements but got ")
502 << aLen;
503 if (bLen != 64)
504 return emitOpError(
505 "for 32x16x128, sourceB must have 64 elements but got ")
506 << bLen;
507
508 // For 32x16x128, matrix A uses all 32 lanes so a_first_scale_lane must be
509 // 0.
510 if (getAFirstScaleLane() != 0)
511 return emitOpError("for 32x16x128, a_first_scale_lane must be 0");
512 }
513
514 // Validate scale types and their compatibility with matrix element types.
515 auto scaleAType = cast<VectorType>(getScaleA().getType());
516 auto scaleBType = cast<VectorType>(getScaleB().getType());
517 Type scaleAElemType = scaleAType.getElementType();
518 Type scaleBElemType = scaleBType.getElementType();
519
520 // Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN).
521 if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
522 return emitOpError(
523 "scale operands must have f8 element types (E8M0FNU or E4M3FN)");
524
525 // Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid.
526 if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
527 return success();
528
529  // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E4M3).
530 if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
531 isF4(bElemType) && isE4M3(scaleBElemType))
532 return success();
533
534  // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E4M3), Scale B (E8M0).
535 if (isF4(aElemType) && isE4M3(scaleAElemType) &&
536 (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
537 return success();
538
539 // Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3).
540 if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
541 isE4M3(scaleBElemType))
542 return success();
543
544 // No valid combination matched.
545 return emitOpError("invalid combination of matrix and scale types: ")
546 << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
547 << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
548}
549
550//===----------------------------------------------------------------------===//
551// MFMAOp
552//===----------------------------------------------------------------------===//
553LogicalResult MFMAOp::verify() {
554 constexpr uint32_t waveSize = 64;
555  Builder b(getContext());
556
557 Type sourceType = getSourceA().getType();
558 Type destType = getDestC().getType();
559
560 Type sourceElem = sourceType, destElem = destType;
561 uint32_t sourceLen = 1, destLen = 1;
562 if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
563 sourceLen = sourceVector.getNumElements();
564 sourceElem = sourceVector.getElementType();
565 }
566 if (auto destVector = dyn_cast<VectorType>(destType)) {
567 destLen = destVector.getNumElements();
568 destElem = destVector.getElementType();
569 }
570
571 Type sourceBType = getSourceB().getType();
572 if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
573 int64_t sourceBLen = 1;
574 Type sourceBElem = sourceBType;
575 if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
576 sourceBLen = sourceBVector.getNumElements();
577 sourceBElem = sourceBVector.getElementType();
578 }
579 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
580 !sourceBElem.isFloat(4))
581 return emitOpError("expected both source operands to have small-float "
582 "elements if one does");
583 if (sourceLen != sourceBLen)
584 return emitOpError(
585 "expected both small-float source vectors to have the same length");
586 } else {
587 if (sourceType != sourceBType)
588 return emitOpError("expected both non-small-float source operand types "
589 "to match exactly");
590 }
591 // Normalize the wider integer types the compiler expects to i8.
592 if (sourceElem.isInteger(32)) {
593 sourceLen *= 4;
594 sourceElem = b.getI8Type();
595 }
596 if (sourceElem.isInteger(64)) {
597 sourceLen *= 8;
598 sourceElem = b.getI8Type();
599 }
600
601 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
602 if (sourceLen != numSourceElems)
603 return emitOpError("expected " + Twine(numSourceElems) +
604 " source values for this operation but got " +
605 Twine(sourceLen));
606
607 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
608 if (destLen != numDestElems)
609 return emitOpError("expected " + Twine(numDestElems) +
610 " result values for this operation but got " +
611 Twine(destLen));
612
613 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
614 return emitOpError(
615 "double-precision ops do not support permuting lanes of B");
616 if (destElem.isF64() && getCbsz() != 0)
617 return emitOpError(
618 "double-precision ops do not support permuting lanes of A");
619 if (getAbid() >= (1u << getCbsz()))
620 return emitOpError(
621 "block ID for permuting A (abid) must be below 2 ** cbsz");
622
623 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
624 return emitOpError(
625 "negation flags only available for double-precision operations");
626
627 return success();
628}
629
630//===----------------------------------------------------------------------===//
631// SparseMFMAOp
632//===----------------------------------------------------------------------===//
633
634LogicalResult SparseMFMAOp::verify() {
635 constexpr uint32_t waveSize = 64;
636
637 auto sparseType = cast<VectorType>(getSourceA().getType());
638 auto denseType = cast<VectorType>(getSourceB().getType());
639 auto destType = cast<VectorType>(getDestC().getType());
640
641 Type sparseElem = sparseType.getElementType();
642 Type denseElem = denseType.getElementType();
643 int64_t sparseLen = sparseType.getNumElements();
644 int64_t denseLen = denseType.getNumElements();
645 int64_t destLen = destType.getNumElements();
646
647 if (denseLen != 2 * sparseLen)
648 return emitOpError("expected dense source operand to have exactly double "
649 "the number of elements of the sparse source operand");
650
651 // Check that source element types are compatible.
652 // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
653 // For other types, element types must match exactly.
654 bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
655 if (!bothFloat8 && sparseElem != denseElem)
656 return emitOpError(
657 "expected source operands to have the same element type");
658
659 // When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
660 // When CBSZ != 0, the first index set is always used (ABID ignored).
661 bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
662 // 8-bit source: ABID selects one of two 16-bit index sets.
663 if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
664 return emitOpError("ABID must be 0 or 1 for 8-bit source data");
665 // 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid).
666 if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
667 return emitOpError("ABID must be between 0 and 3 for 16-bit source data");
668
669 // Validate sparseIdx type matches source element type.
670 auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
671 if (is8BitSource) {
672 // 8-bit source data requires vector<2xi16> sparse indices.
673 if (sparseIdxType.getNumElements() != 2 ||
674 !sparseIdxType.getElementType().isInteger(16))
675 return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
676 "source data, but got ")
677 << getSparseIdx().getType();
678 } else {
679 // 16-bit source data requires vector<4xi8> sparse indices.
680 if (sparseIdxType.getNumElements() != 4 ||
681 !sparseIdxType.getElementType().isInteger(8))
682 return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
683 "source data, but got ")
684 << getSparseIdx().getType();
685 }
686
687 int64_t expectedSourceElems = (getM() * getK()) / waveSize;
688 if (denseLen != expectedSourceElems)
689 return emitOpError("expected " + Twine(expectedSourceElems) +
690 " source values for this operation but got " +
691 Twine(denseLen));
692
693 int64_t expectedDestElems = (getM() * getN()) / waveSize;
694 if (destLen != expectedDestElems)
695 return emitOpError("expected " + Twine(expectedDestElems) +
696 " result values for this operation but got " +
697 Twine(destLen));
698
699 return success();
700}
701
702//===----------------------------------------------------------------------===//
703// DPPOp
704//===----------------------------------------------------------------------===//
705LogicalResult DPPOp::verify() {
706 Type srcType = getSrc().getType();
707 if (srcType.getIntOrFloatBitWidth() > 64) {
708 return emitOpError("integer and floating point types larger than 64 bits "
709 "are not supported");
710 }
711
712 DPPPerm kind = getKind();
713 Attribute permArgument = getPermArgument().value_or(Attribute{});
714
715 switch (kind) {
716
717 case DPPPerm::quad_perm: {
718 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
719 if (!quadPermAttr || quadPermAttr.size() != 4) {
720 return emitOpError("quad_perm attribute must have exactly 4 elements");
721 }
722 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
723 int32_t num = elem.getInt();
724 if (num < 0 || num > 3) {
725 return emitOpError(
726 "Each element of quad_perm must be in the range [0, 3]");
727 }
728 }
729 } break;
730
731 case DPPPerm::row_shl:
732 case DPPPerm::row_shr:
733 case DPPPerm::row_ror: {
734 if (!permArgument) {
735 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
736 "' value not specified");
737 }
738 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
739 uint32_t attrValue = intAttr.getInt();
740 if (attrValue < 1 || attrValue > 15) {
741 return emitOpError("Attribute value must be between 1 and 15");
742 }
743 }
744 } break;
745
746 case DPPPerm::wave_shl:
747 case DPPPerm::wave_shr:
748 case DPPPerm::wave_rol:
749 case DPPPerm::wave_ror:
750 case DPPPerm::row_mirror:
751 case DPPPerm::row_half_mirror:
752 case DPPPerm::row_bcast_15:
753 case DPPPerm::row_bcast_31: {
754 if (permArgument && !isa<UnitAttr>(permArgument)) {
755 return emitOpError("Expected unit attribute for permArgument, but found "
756 "non-trivial argument");
757 }
758 break;
759 }
760 }
761 return success();
762}
763
764//===----------------------------------------------------------------------===//
765// PermlaneSwapOp
766//===----------------------------------------------------------------------===//
767LogicalResult PermlaneSwapOp::verify() {
768 unsigned rowLength = getRowLength();
769
770 if (rowLength != 16 && rowLength != 32)
771 return emitOpError("row_length attribute must either be 16 or 32.");
772
773 return success();
774}
775
776/// Remove an amdgpu.lds_barrier that is immediately followed by another one.
777static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op,
778 PatternRewriter &rewriter) {
779 if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
780 rewriter.eraseOp(op);
781 return success();
782 }
783 return failure();
784}
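A sketch of the rewrite (surrounding IR omitted): in a block containing

    amdgpu.lds_barrier
    amdgpu.lds_barrier

the first barrier is erased because its immediate successor is also an lds_barrier; repeated application leaves a single amdgpu.lds_barrier.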
785
786void LDSBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
787 MLIRContext *context) {
788  results.add(eraseRedundantLDSBarrierOps);
789}
790
791//===----------------------------------------------------------------------===//
792// MemoryCounterWaitOp
793//===----------------------------------------------------------------------===//
794
795namespace {
796/// Fuse adjacent memory counter wait ops, taking the minimum value of the
797/// counters.
798struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
799 using Base::Base;
800
801 LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
802 PatternRewriter &rewriter) const override {
803 auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
804 if (!next)
805 return failure();
806
807 auto setters = {&MemoryCounterWaitOp::setLoad,
808 &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
809 &MemoryCounterWaitOp::setExp,
810 &MemoryCounterWaitOp::setTensor};
811 auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
812 op.getTensor()};
813 auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
814 next.getExp(), next.getTensor()};
815 rewriter.modifyOpInPlace(op, [&] {
816 for (auto [setter, lhs, rhs] :
817 llvm::zip_equal(setters, lhsVals, rhsVals)) {
818 if (lhs && rhs) {
819 (op.*setter)(std::min(*lhs, *rhs));
820 } else if (lhs) {
821 (op.*setter)(*lhs);
822 } else if (rhs) {
823 (op.*setter)(*rhs);
824 }
825 }
826 });
827 rewriter.eraseOp(next);
828 return success();
829 }
830};
831} // namespace
832
833void MemoryCounterWaitOp::getCanonicalizationPatterns(
834 RewritePatternSet &results, MLIRContext *context) {
835 results.add<FuseMemoryCounterWaitOp>(context);
836}
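As a concrete reading of the fusion rule (counter values are illustrative): a wait with load = 2 and ds = 4 that is immediately followed by a wait with load = 1 and exp = 0 is rewritten to a single wait with load = 1, ds = 4, exp = 0 (the per-counter minimum where both ops set a counter, and the specified value where only one of them does), after which the second op is erased.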
837
838//===----------------------------------------------------------------------===//
839// GatherToLDSOp
840//===----------------------------------------------------------------------===//
841
842LogicalResult GatherToLDSOp::verify() {
843 MemRefType srcType = cast<MemRefType>(getSrc().getType());
844 MemRefType dstType = cast<MemRefType>(getDst().getType());
845
846 if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
847 return emitOpError("destination type inner most dim must be contiguous");
848
849 auto elemType = srcType.getElementType();
850 // Check $src and $dst element types are the same.
851 if (elemType != dstType.getElementType())
852 return emitOpError("source and destination element types must match");
853
854 // copy type sizes should be 1, 2, 4, 12 or 16 bytes.
855 auto transferType = getTransferType();
856 int transferSize;
857 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
858 transferSize = vectorTransfer.getNumElements() *
859 vectorTransfer.getElementTypeBitWidth();
860 } else {
861 transferSize = transferType.getIntOrFloatBitWidth();
862 }
863 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
864 return emitOpError(
865        "Transferring type size must be 8, 16, 32, 96 or 128 bits");
866
867 if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
868 !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
869 return emitOpError(
870 "source memory address space must be global or fat raw buffer");
871
872 if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
873 return emitOpError("destination memory address space must be Workgroup");
874
875 return success();
876}
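For the transfer-size rule above, a few illustrative transfer types: f32 (32 bits), vector<12xi8> (96 bits), and vector<8xf16> (128 bits) are accepted, while vector<2xf32> (64 bits) is rejected because 64 is not among the allowed 8/16/32/96/128-bit sizes.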
877
878namespace {
879/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
880/// information or changes layout, the cast can be skipped.
881struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
882  using OpRewritePattern::OpRewritePattern;
883
884 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
885 PatternRewriter &rewriter) const override {
886 bool modified = false;
887 auto foldCast = [&](OpOperand &operand) {
888 if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
889 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
890 rewriter.modifyOpInPlace(gatherOp,
891 [&] { operand.assign(castOp.getSource()); });
892 modified = true;
893 }
894 }
895 };
896
897 foldCast(gatherOp.getSrcMutable());
898 foldCast(gatherOp.getDstMutable());
899
900 return success(modified);
901 }
902};
903} // namespace
904
905void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
906 MLIRContext *context) {
907 results.add<FoldGatherToLDSOfCast>(context);
908}
909
910//===----------------------------------------------------------------------===//
911// TransposeLoadOp
912//===----------------------------------------------------------------------===//
913
914LogicalResult TransposeLoadOp::verify() {
915 MemRefType srcType = cast<MemRefType>(getSrc().getType());
916
917 if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
918 return emitOpError("source memory address space must be Workgroup");
919
920 auto transferType = cast<VectorType>(getType());
921 size_t numElements = transferType.getNumElements();
922 size_t elementTypeSize =
923 transferType.getElementType().getIntOrFloatBitWidth();
924
925 // ElementSize -> NumElements
926 const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
927 {4, 16},
928 {6, 16},
929 {8, 8},
930 {16, 4},
931 };
932
933 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
934 if (validNumElems == kValidLoadSizeMap.end())
935 return emitOpError("Unsupported element type size for transpose load: ")
936 << elementTypeSize << " bits";
937
938 if (numElements != validNumElems->second)
939 return emitOpError(
940 "Transferring type size mismatch: expected num of elements: ")
941 << validNumElems->second;
942
943 return success();
944}
945
946//===----------------------------------------------------------------------===//
947// MakeDmaBaseOp
948//===----------------------------------------------------------------------===//
949
950template <typename BaseOp>
951static LogicalResult verifyBase(BaseOp op) {
952 auto ldsType = cast<MemRefType>(op.getLds().getType());
953 auto globalType = cast<MemRefType>(op.getGlobal().getType());
954 if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
955 return op.emitOpError(
956 "lds memref must have workgroup address space attribute.");
957 if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
958 return op.emitOpError(
959 "global memref must have global address space attribute.");
960
961 Type elementType = ldsType.getElementType();
962 unsigned width = elementType.getIntOrFloatBitWidth();
963
964 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
965 return op.emitOpError(
966 "element type must be 1, 2, 4, or 8 bytes long but type was ")
967 << width << " bits long.";
968 return success();
969}
970
971LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
972
973//===----------------------------------------------------------------------===//
974// MakeGatherDmaBaseOp
975//===----------------------------------------------------------------------===//
976
977LogicalResult
978TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
979 Type elementType, Type indexType) {
980 unsigned width = elementType.getIntOrFloatBitWidth();
981 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
982 return emitError()
983 << "element type must be 1, 2, 4, or 8 bytes wide but type "
984 << elementType << " is " << width / 8 << " bytes wide.";
985 MLIRContext *ctx = elementType.getContext();
986  Type i16 = IntegerType::get(ctx, 16);
987  Type i32 = IntegerType::get(ctx, 32);
988 if (!llvm::is_contained({i16, i32}, indexType))
989 return emitError() << "index type must be i16 or i32 but index type is "
990 << indexType << ".";
991 return success();
992}
993
994LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
995
996//===----------------------------------------------------------------------===//
997// MakeDmaDescriptorOp
998//===----------------------------------------------------------------------===//
999
1000template <typename DescriptorOp>
1001static LogicalResult verifyDescriptorOp(DescriptorOp op) {
1002 ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();
1003
1004 if (globalStaticStrides.empty())
1005 return op.emitOpError("strides must not be empty.");
1006 if (globalStaticStrides.back() != 1)
1007 return op.emitOpError("strides for the innermost dimension must be 1.");
1008
1009 ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
1010 size_t rank = globalStaticSizes.size();
1011 if (rank > 5)
1012 return op.emitOpError("tensor and tile must be at most of rank 5.");
1013 if (rank != globalStaticStrides.size())
1014 return op.emitOpError("strides and sizes must have same rank.");
1015
1016 ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
1017 if (rank != sharedStaticSizes.size())
1018 return op.emitOpError("tensor must have same rank as tile.");
1019
1020 unsigned elementTypeWidth = op.getElementTypeWidth();
1021 if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
1022 return op.emitOpError(
1023 "element type width must be 1, 2, 4 or 8 bytes, but was ")
1024 << elementTypeWidth << " bits long";
1025
1026 if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
1027 auto atomicBarrierAddressType =
1028 cast<MemRefType>(atomicBarrierAddress.getType());
1029 bool barrierInLDS =
1030 hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
1031 if (!barrierInLDS)
1032 return op.emitOpError("atomic barrier address must be in LDS.");
1033 }
1034
1035 if (op.getEarlyTimeout() && !op.getWorkgroupMask())
1036 return op.emitOpError(
1037 "early timeout does not apply when workgroup_mask is not set.");
1038 return success();
1039}
1040
1041template <typename DescriptorOp, typename FoldAdaptor>
1042static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
1043 SmallVector<OpFoldResult> mixedGlobalSizes(op.getMixedGlobalSizes());
1044 SmallVector<OpFoldResult> mixedGlobalStrides(op.getMixedGlobalStrides());
1045 SmallVector<OpFoldResult> mixedSharedSizes(op.getMixedSharedSizes());
1046
1047 if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
1048 /*onlyNonZero=*/true)) &&
1049 failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
1050 /*onlyNonZero=*/true)) &&
1051 failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
1052 /*onlyNonZero=*/true)))
1053 return nullptr;
1054
1055 SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
1056 dynamicSharedSizes;
1057 SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
1058 staticSharedSizes;
1059
1060 dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
1061 staticGlobalSizes);
1062 op.setGlobalStaticSizes(staticGlobalSizes);
1063 op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
1064
1065 dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
1066 staticGlobalStrides);
1067 op.setGlobalStaticStrides(staticGlobalStrides);
1068 op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
1069
1070 dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
1071 staticSharedSizes);
1072 op.setSharedStaticSizes(staticSharedSizes);
1073 op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
1074 return op.getResult();
1075}
1076
1077LogicalResult MakeDmaDescriptorOp::verify() {
1078 return verifyDescriptorOp(*this);
1079}
1080
1081OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1082 return foldDescriptorOp(*this, adaptor);
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// MakeGatherDmaDescriptorOp
1087//===----------------------------------------------------------------------===//
1088
1089LogicalResult MakeGatherDmaDescriptorOp::verify() {
1090 ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
1091 size_t rank = globalStaticSizes.size();
1092 if (rank > 2)
1093 return emitOpError(
1094 "tensor and tile must be at most of rank two in gather mode.");
1095  Value indices = getIndices();
1096  Type elementType = cast<VectorType>(indices.getType()).getElementType();
1097 if (elementType != getBase().getType().getIndexType())
1098 return emitOpError("indices' element type must match base's element type.");
1099
1100 return verifyDescriptorOp(*this);
1101}
1102
1103OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1104 return foldDescriptorOp(*this, adaptor);
1105}
1106
1107//===----------------------------------------------------------------------===//
1108// ScaledMFMAOp
1109//===----------------------------------------------------------------------===//
1110
1111namespace {
1112/// Check whether the scales input is extracted one element at a time from a
1113/// larger vector; if so, pack the scales and extract four of them at once.
1114struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
1115  using OpRewritePattern::OpRewritePattern;
1116
1117 LogicalResult matchAndRewrite(ScaledMFMAOp op,
1118 PatternRewriter &rewriter) const override {
1119 Location loc = op.getLoc();
1120 auto setOpsel = [&op](unsigned idx, int64_t val) {
1121 switch (idx) {
1122 case 3:
1123 op.setScalesIdxA(val);
1124 break;
1125 case 4:
1126 op.setScalesIdxB(val);
1127 break;
1128 default:
1129 break;
1130 }
1131 };
1132
1133 // For every scale operand of this ScaledMFMAOp, if the scale is produced by
1134 // the extraction of a single scale from some vector, then attempt to
1135 // extract 4 values from that vector instead.
1136 //
1137 // Example: (f8 here means f8E8M0FNU)
1138 // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
1139 // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
1140 // amdgpu.scaled_mfma(%scale[0] * ...
1141 //
1142 // rewrite to:
1143 //
1144 // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
1145 // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
1146 // amdgpu.scaled_mfma(%scale[0-3] * ...
1147 //
1148 // This creates duplicate shape_casts for every use but these will be
1149 // removed in CSE.
1150 for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
1151 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
1152 if (!insertOp) {
1153 return rewriter.notifyMatchFailure(op,
1154 "defining op not a vector.insert");
1155 }
1156      // If the inserted value is not a single scalar, it has already been packed.
1157 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
1158 return rewriter.notifyMatchFailure(
1159 op, "scaled mfma operand already packed");
1160 }
1161
1162 auto extractOp =
1163 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
1164 if (!extractOp) {
1165 return rewriter.notifyMatchFailure(op,
1166 "defining op not a vector.extract");
1167 }
1168
1169 Value scaleSrc = extractOp.getOperand(0);
1170 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
1171 if (!scaleSrcType) {
1172 return rewriter.notifyMatchFailure(op, "not a vector type");
1173 }
1174
1175 // We do not handle dynamic dims yet, assume that the input is padded to
1176 // a static shape now.
1177 if (!scaleSrcType.hasStaticShape()) {
1178 return rewriter.notifyMatchFailure(op,
1179 "dynamic dims not yet supported");
1180 }
1181
1182 int64_t numElements = scaleSrcType.getNumElements();
1183 if (numElements <= 4) {
1184 return rewriter.notifyMatchFailure(
1185 op, "no packing if # of scales less than four");
1186 }
1187
1188 // Find a linearized idx using the size and offsets of the extract op.
1189 auto extractedPos = llvm::to_vector_of<int64_t>(
1190 llvm::reverse(extractOp.getStaticPosition()));
1191 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
1192 int64_t scaleSrcRank = scaleSrcType.getRank();
1193 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
1194 for (int64_t i = 1; i < scaleSrcRank; ++i) {
1195 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
1196 }
1197 int64_t idx = linearize(extractedPos, extractSizes);
1198
1199      // All n scales (where n is the total number of scales) are now
1200      // extracted in chunks of 4 elements: the original vector of scales is
1201      // divided into groups of 4 elements at offsets 0, 4, ..., n - 4. Every
1202      // extraction of a scale at a particular index is replaced with an
1203      // extraction of the entire group of 4 elements to which that index
1204      // belongs, with opsel selecting the scale within the group.
1205      //
1206      // If the number of scales is not divisible by 4, the trailing scales
1207      // are extracted as part of a chunk of 4 elements starting at offset
1208      // n - 4.
1209 int64_t offset = idx - (idx % 4);
1210 int64_t opsel = idx - offset;
1211 int64_t size = 4l;
1212      // Accommodate remaining elements in the case of non-4-divisible vectors.
1213 if (numElements - offset < size) {
1214 opsel = size - (numElements - idx);
1215 offset = numElements - 4l;
1216 }
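      // Worked example (illustrative numbers): with numElements = 10 and
      // idx = 9, the chunk would start at offset 8 with only two scales left,
      // so it is moved back to offset = 10 - 4 = 6 and
      // opsel = 4 - (10 - 9) = 3, which still selects element 6 + 3 = 9.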
1217 Type scaleSrcElemType = scaleSrcType.getElementType();
1218 auto newSrcType =
1219 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
1220 Value newScaleSrc =
1221 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
1222 auto extract = vector::ExtractStridedSliceOp::create(
1223 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
1224 ArrayRef{int64_t(1)});
1225 rewriter.modifyOpInPlace(op, [&] {
1226 op->setOperand(opIdx, extract);
1227 setOpsel(opIdx, opsel);
1228 });
1229 }
1230 return success();
1231 }
1232};
1233} // namespace
1234
1235void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
1236 MLIRContext *context) {
1237 results.add<PackScales>(context);
1238}
1239
1240#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
1241
1242#define GET_ATTRDEF_CLASSES
1243#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
1244
1245#define GET_TYPEDEF_CLASSES
1246#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
1247
1248#define GET_OP_CLASSES
1249#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"