1//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/Diagnostics.h"
25#include "mlir/IR/Matchers.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34#include <algorithm>
35#include <cstdint>
36#include <limits>
37#include <optional>
38
39using namespace mlir;
40using namespace mlir::amdgpu;
41
42#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
43
44namespace {
45struct AMDGPUInlinerInterface final : DialectInlinerInterface {
46 using DialectInlinerInterface::DialectInlinerInterface;
47 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
48 return true;
49 }
50};
51} // namespace
52
53void AMDGPUDialect::initialize() {
54 addOperations<
55#define GET_OP_LIST
56#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
57 >();
58 addAttributes<
59#define GET_ATTRDEF_LIST
60#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
61 >();
62 addInterfaces<AMDGPUInlinerInterface>();
63}
64
65//===----------------------------------------------------------------------===//
66// 8-bit float ops
67//===----------------------------------------------------------------------===//
68LogicalResult PackedTrunc2xFp8Op::verify() {
69 if (getExisting() && getExisting().getType() != getResult().getType())
70 return emitOpError("existing values must have same type as result");
71 return success();
72}
73
74LogicalResult PackedStochRoundFp8Op::verify() {
75 if (getExisting() && getExisting().getType() != getResult().getType())
76 return emitOpError("existing values must have same type as result");
77 return success();
78}
79
80//===----------------------------------------------------------------------===//
81// mxfp float ops
82//===----------------------------------------------------------------------===//
83LogicalResult PackedScaledTruncOp::verify() {
84 if (getExisting() && getExisting().getType() != getResult().getType())
85 return emitOpError("existing values must have same type as result");
86 return success();
87}
88
89//===----------------------------------------------------------------------===//
90// FatRawBufferCastOp
91//===----------------------------------------------------------------------===//
92
93/// Convert the type `source` to one with the same sizes and strides - and
94/// the same offset, unless `resetOffset` is true, in which case the offset is
95/// reset to 0. If the offset should be reset but the layout of `source` is
96/// neither the identity layout nor a strided layout, this function fails.
97static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
98 bool resetOffset) {
99 MLIRContext *ctx = source.getContext();
100 MemRefType::Builder mb(source);
101 mb.setMemorySpace(
102 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
103 MemRefLayoutAttrInterface layout = source.getLayout();
104 if (resetOffset && !layout.isIdentity()) {
105 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
106 if (!stridedLayout)
107 return failure();
108 MemRefLayoutAttrInterface newLayout =
109 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
110 // Special case: if resetting the offset causes the strided layout to become
111 // the identity layout, then reset to the identity layout.
112 // TODO: this'll get a lot simpler when we have the contiguous layout.
113 SmallVector<int64_t> stridesIfIdentity;
114 if (source.hasStaticShape()) {
115 stridesIfIdentity = computeSuffixProduct(source.getShape());
116 } else if (source.getRank() <= 1) {
117 stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
118 }
119 if (stridesIfIdentity == stridedLayout.getStrides()) {
120 newLayout = AffineMapAttr::get(
121 AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
122 }
123 mb.setLayout(newLayout);
124 }
125 return (MemRefType)(mb);
126}
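// Illustrative example (memref and attribute syntax assumed, not taken from
// this file): with resetOffset = true, a source type such as
//   memref<4x8xf32, strided<[8, 1], offset: ?>>
// keeps its sizes and strides, drops the offset, and, since [8, 1] equals the
// suffix product of the static shape, collapses to the identity layout:
//   memref<4x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// With resetOffset = false, only the memory space is changed.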
127
128LogicalResult FatRawBufferCastOp::inferReturnTypes(
129 MLIRContext *context, std::optional<Location> location, ValueRange operands,
130 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
131 SmallVectorImpl<Type> &inferredReturnTypes) {
132 Adaptor adaptor(operands, attributes, properties, regions);
133 auto sourceType =
134 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
135 if (!sourceType)
136 return failure();
137 FailureOr<MemRefType> resultType =
138 getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
139 if (failed(resultType))
140 return failure();
141 inferredReturnTypes = SmallVector<Type>{*resultType};
142 return success();
143}
144
145LogicalResult FatRawBufferCastOp::verify() {
146 FailureOr<MemRefType> expectedResultType =
147 getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
148 if (failed(expectedResultType))
149 return emitOpError("source type ")
150 << getSource().getType() << " can't have its offset reset";
151 if (getResult().getType() != *expectedResultType)
152 return emitOpError("expected result type to be ")
153 << *expectedResultType << " but got " << getResult().getType();
154 return success();
155}
156
157static bool hasGlobalMemorySpace(Attribute memorySpace) {
158 if (!memorySpace)
159 return true;
160 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
161 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
162 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
163 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
164 return false;
165}
166
167static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
168 if (!memorySpace)
169 return false;
170 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
171 return intMemorySpace.getInt() == 3;
172 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
173 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
174 return false;
175}
176
177static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
178 if (!memorySpace)
179 return false;
180 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
181 return intMemorySpace.getInt() == 7;
182 if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
183 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
184 return false;
185}
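// Note on the numeric spaces accepted above (assumed to follow the AMDGPU
// target's address-space numbering): 0 (flat) and 1 (global) count as global
// memory, 3 is the workgroup / LDS space, and 7 is the buffer fat pointer
// space that corresponds to #amdgpu.address_space<fat_raw_buffer>.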
186
187//===----------------------------------------------------------------------===//
188// RawBuffer*Op
189//===----------------------------------------------------------------------===//
190template <typename T>
191static LogicalResult verifyRawBufferOp(T &op) {
192 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
193 bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
194
195 if (!isGlobal)
196 return op.emitOpError(
197 "Buffer ops must operate on a memref in global memory");
198 if (!bufferType.hasRank())
199 return op.emitOpError(
200 "Buffer ops cannot meaningfully operate on an unranked memref");
201 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
202 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
203 " indices to memref");
204 return success();
205}
206
207LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
208
209LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
210
211LogicalResult RawBufferAtomicFaddOp::verify() {
212 return verifyRawBufferOp(*this);
213}
214
215LogicalResult RawBufferAtomicFmaxOp::verify() {
216 return verifyRawBufferOp(*this);
217}
218
219LogicalResult RawBufferAtomicSmaxOp::verify() {
220 return verifyRawBufferOp(*this);
221}
222
223LogicalResult RawBufferAtomicUminOp::verify() {
224 return verifyRawBufferOp(*this);
225}
226
227LogicalResult RawBufferAtomicCmpswapOp::verify() {
228 return verifyRawBufferOp(*this);
229}
230
231static std::optional<uint32_t> getConstantUint32(Value v) {
232 APInt cst;
233 if (!v.getType().isInteger(32))
234 return std::nullopt;
235 if (matchPattern(v, m_ConstantInt(&cst)))
236 return cst.getZExtValue();
237 return std::nullopt;
238}
239
240template <typename OpType>
241static bool staticallyOutOfBounds(OpType op) {
242 if (!op.getBoundsCheck())
243 return false;
244 MemRefType bufferType = op.getMemref().getType();
245 if (!bufferType.hasStaticShape())
246 return false;
247 int64_t offset;
248 SmallVector<int64_t> strides;
249 if (failed(bufferType.getStridesAndOffset(strides, offset)))
250 return false;
251 int64_t result = offset + op.getIndexOffset().value_or(0);
252 if (op.getSgprOffset()) {
253 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
254 if (!sgprOffset)
255 return false;
256 result += *sgprOffset;
257 }
258 if (strides.size() != op.getIndices().size())
259 return false;
260 int64_t indexVal = 0;
261 for (auto pair : llvm::zip(strides, op.getIndices())) {
262 int64_t stride = std::get<0>(pair);
263 Value idx = std::get<1>(pair);
264 std::optional<uint32_t> idxVal = getConstantUint32(idx);
265 if (!idxVal)
266 return false;
267 indexVal += stride * *idxVal;
268 }
269 result += indexVal;
270 if (result > std::numeric_limits<uint32_t>::max())
271 // Overflow means don't drop
272 return false;
273 return result >= bufferType.getNumElements();
274}
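// Worked example (hypothetical values): for a bounds-checked buffer of type
// memref<16x16xf32>, the strides are [16, 1] and there are 256 elements.
// Constant indices (17, 0) with no index or SGPR offsets give
// result = 17 * 16 + 0 = 272 >= 256, so the access is statically out of
// bounds and the rewrites below apply.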
275
276namespace {
277template <typename OpType>
278struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
279 using OpRewritePattern<OpType>::OpRewritePattern;
280
281 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
282 if (!staticallyOutOfBounds(op))
283 return failure();
284 Type loadType = op.getResult().getType();
285 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
286 rw.getZeroAttr(loadType));
287 return success();
288 }
289};
290
291template <typename OpType>
292struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
293 using OpRewritePattern<OpType>::OpRewritePattern;
294
295 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
296 if (!staticallyOutOfBounds(op))
297 return failure();
298
299 rw.eraseOp(op);
300 return success();
301 }
302};
303} // end namespace
304
305void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
306 MLIRContext *context) {
307 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
308}
309
310void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
311 MLIRContext *context) {
312 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
313}
314
315void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
316 RewritePatternSet &results, MLIRContext *context) {
317 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
318}
319
320void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
321 RewritePatternSet &results, MLIRContext *context) {
322 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
323}
324
325void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
326 RewritePatternSet &results, MLIRContext *context) {
327 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
328}
329
330void RawBufferAtomicUminOp::getCanonicalizationPatterns(
331 RewritePatternSet &results, MLIRContext *context) {
332 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
333}
334
335void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
336 RewritePatternSet &results, MLIRContext *context) {
337 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
338 context);
339}
340
341//===----------------------------------------------------------------------===//
342// ScaledExtPacked816Op
343//===----------------------------------------------------------------------===//
344LogicalResult ScaledExtPacked816Op::verify() {
345 int blockSize = getBlockSize();
346 assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
347
348 int firstScaleByte = getFirstScaleByte();
349 int firstScaleLane = getFirstScaleLane();
350 auto sourceType = cast<VectorType>(getSource().getType());
351 Type elementType = sourceType.getElementType();
352 auto floatType = cast<FloatType>(elementType);
353 unsigned bitWidth = floatType.getWidth();
354
355 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
356
357 const bool is_fp8 = bitWidth == 8;
358 const bool is_block_16 = blockSize == 16;
359
360 if (!is_fp8) {
361 if (is_block_16) {
362 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
363 return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
364 "or 1 for f4 and f6.");
365 }
366 } else {
367 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
368 return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
369 "or 2 for f4 and f6.");
370 }
371 }
372 } else {
373 if (is_block_16) {
374 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
375 ((firstScaleLane == 1) && (firstScaleByte == 2));
376 if (!is_valid) {
377 return emitOpError("blockSize of 16 can only have (firstScaleLane, "
378 "firstScaleByte) be (0, 0) or (1, 2) for f8.");
379 }
380 }
381 }
382
383 return success();
384}
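// Summary of the combinations accepted above (derived from these checks only,
// not a full ISA statement): f4/f6 sources allow firstScaleByte 0 or 1 with
// blockSize 16 and 0 or 2 with blockSize 32; f8 sources with blockSize 16
// require (firstScaleLane, firstScaleByte) to be (0, 0) or (1, 2), and f8
// with blockSize 32 is not constrained further here.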
385
386//===----------------------------------------------------------------------===//
387// WMMAOp
388//===----------------------------------------------------------------------===//
389
390ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
391 IntegerAttr &m, IntegerAttr &n,
392 IntegerAttr &k) {
393 SmallVector<int64_t, 3> dimensions;
394 if (parser.parseDimensionList(dimensions, false, false))
395 return failure();
396 if (dimensions.size() != 3)
397 return parser.emitError(parser.getCurrentLocation())
398 << "expected 3 dimensions in MNK dimension list";
399
400 m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
401 n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
402 k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
403 return success();
404}
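// Example (assumed usage): the custom<MNKDimensionList> directive accepts a
// static dimension list such as 16x16x32 and produces the i32 attributes
// m = 16, n = 16 and k = 32; anything other than exactly three static
// dimensions is rejected with the error above.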
405
406LogicalResult WMMAOp::verify() {
407 auto sourceAType = cast<VectorType>(getSourceA().getType());
408 auto sourceBType = cast<VectorType>(getSourceB().getType());
409 auto destType = cast<VectorType>(getDestC().getType());
410
411 Type sourceAElemType = sourceAType.getElementType();
412 Type sourceBElemType = sourceBType.getElementType();
413 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
414 return emitOpError("source vectors have different lengths: ")
415 << sourceAType << " vs. " << sourceBType;
416 }
417
418 bool isDestFloat = destType.getElementType().isFloat();
419 bool isSrcFloat = sourceAElemType.isFloat();
420
421 if (isDestFloat && !isSrcFloat)
422 return emitOpError("expected float sources with float destination");
423 if (!isDestFloat && isSrcFloat)
424 return emitOpError("expected int sources with int destination");
425
426 if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
427 return emitOpError(
428 "source element types must match (except for fp8/bf8) but have ")
429 << sourceAType << " and " << sourceBType;
430 }
431
432 if (isSrcFloat) {
433 if (getClamp())
434 return emitOpError("clamp flag is not supported for float types");
435 if (getUnsignedA() || getUnsignedB())
436 return emitOpError("unsigned flags are not supported for float types");
437 }
438 return success();
439}
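// Illustrative combinations, judged only against the checks above: f16
// sources such as vector<8xf16> * vector<8xf16> with a vector<8xf32>
// accumulator are accepted, and 8-bit float A/B element types may differ
// (e.g. fp8 with bf8); float sources with an integer accumulator, mismatched
// source lengths, or clamp/unsigned flags on float inputs are rejected.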
440
441//===----------------------------------------------------------------------===//
442// MFMAOp
443//===----------------------------------------------------------------------===//
444LogicalResult MFMAOp::verify() {
445 constexpr uint32_t waveSize = 64;
446 Builder b(getContext());
447
448 Type sourceType = getSourceA().getType();
449 Type destType = getDestC().getType();
450
451 Type sourceElem = sourceType, destElem = destType;
452 uint32_t sourceLen = 1, destLen = 1;
453 if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
454 sourceLen = sourceVector.getNumElements();
455 sourceElem = sourceVector.getElementType();
456 }
457 if (auto destVector = dyn_cast<VectorType>(destType)) {
458 destLen = destVector.getNumElements();
459 destElem = destVector.getElementType();
460 }
461
462 Type sourceBType = getSourceB().getType();
463 if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
464 int64_t sourceBLen = 1;
465 Type sourceBElem = sourceBType;
466 if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
467 sourceBLen = sourceBVector.getNumElements();
468 sourceBElem = sourceBVector.getElementType();
469 }
470 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
471 !sourceBElem.isFloat(4))
472 return emitOpError("expected both source operands to have small-float "
473 "elements if one does");
474 if (sourceLen != sourceBLen)
475 return emitOpError(
476 "expected both small-float source vectors to have the same length");
477 } else {
478 if (sourceType != sourceBType)
479 return emitOpError("expected both non-small-float source operand types "
480 "to match exactly");
481 }
482 // Normalize the wider integer types the compiler expects to i8.
483 if (sourceElem.isInteger(32)) {
484 sourceLen *= 4;
485 sourceElem = b.getI8Type();
486 }
487 if (sourceElem.isInteger(64)) {
488 sourceLen *= 8;
489 sourceElem = b.getI8Type();
490 }
491
492 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
493 if (sourceLen != numSourceElems)
494 return emitOpError("expected " + Twine(numSourceElems) +
495 " source values for this operation but got " +
496 Twine(sourceLen));
497
498 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
499 if (destLen != numDestElems)
500 return emitOpError("expected " + Twine(numDestElems) +
501 " result values for this operation but got " +
502 Twine(destLen));
503
504 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
505 return emitOpError(
506 "double-precision ops do not support permuting lanes of B");
507 if (destElem.isF64() && getCbsz() != 0)
508 return emitOpError(
509 "double-precision ops do not support permuting lanes of A");
510 if (getAbid() >= (1u << getCbsz()))
511 return emitOpError(
512 "block ID for permuting A (abid) must be below 2 ** cbsz");
513
514 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
515 return emitOpError(
516 "negation flags only available for double-precision operations");
517
518 return success();
519}
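// Worked example (hypothetical shape): for m = 32, n = 32, k = 8 with
// blocks = 1 on a 64-lane wavefront, each lane supplies
// (32 * 8 * 1) / 64 = 4 source values and holds (32 * 32 * 1) / 64 = 16
// accumulator values, e.g. vector<4xf16> sources with a vector<16xf32>
// destination.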
520
521//===----------------------------------------------------------------------===//
522// DPPOp
523//===----------------------------------------------------------------------===//
524LogicalResult DPPOp::verify() {
525 Type srcType = getSrc().getType();
526 if (srcType.getIntOrFloatBitWidth() > 64) {
527 return emitOpError("integer and floating point types larger than 64 bits "
528 "are not supported");
529 }
530
531 DPPPerm kind = getKind();
532 Attribute permArgument = getPermArgument().value_or(Attribute{});
533
534 switch (kind) {
535
536 case DPPPerm::quad_perm: {
537 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
538 if (!quadPermAttr || quadPermAttr.size() != 4) {
539 return emitOpError("quad_perm attribute must have exactly 4 elements");
540 }
541 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
542 int32_t num = elem.getInt();
543 if (num < 0 || num > 3) {
544 return emitOpError(
545 "Each element of quad_perm must be in the range [0, 3]");
546 }
547 }
548 } break;
549
550 case DPPPerm::row_shl:
551 case DPPPerm::row_shr:
552 case DPPPerm::row_ror: {
553 if (!permArgument) {
554 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
555 "' value not specified");
556 }
557 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
558 uint32_t attrValue = intAttr.getInt();
559 if (attrValue < 1 || attrValue > 15) {
560 return emitOpError("Attribute value must be between 1 and 15");
561 }
562 }
563 } break;
564
565 case DPPPerm::wave_shl:
566 case DPPPerm::wave_shr:
567 case DPPPerm::wave_rol:
568 case DPPPerm::wave_ror:
569 case DPPPerm::row_mirror:
570 case DPPPerm::row_half_mirror:
571 case DPPPerm::row_bcast_15:
572 case DPPPerm::row_bcast_31: {
573 if (permArgument && !isa<UnitAttr>(permArgument)) {
574 return emitOpError("Expected unit attribute for permArgument, but found "
575 "non-trivial argument");
576 }
577 break;
578 }
579 }
580 return success();
581}
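// Examples of permArgument values accepted above (illustrative only):
// quad_perm expects a 4-element array such as [3, 2, 1, 0] with every entry
// in [0, 3]; row_shl, row_shr and row_ror expect an integer between 1 and 15;
// the wave_* shifts/rotates and the row_mirror/row_bcast variants take no
// argument at all.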
582
583//===----------------------------------------------------------------------===//
584// PermlaneSwapOp
585//===----------------------------------------------------------------------===//
586LogicalResult PermlaneSwapOp::verify() {
587 unsigned rowLength = getRowLength();
588
589 if (rowLength != 16 && rowLength != 32)
590 return emitOpError("row_length attribute must either be 16 or 32.");
591
592 return success();
593}
594
595//===----------------------------------------------------------------------===//
596// GatherToLDSOp
597//===----------------------------------------------------------------------===//
598
599LogicalResult GatherToLDSOp::verify() {
600 MemRefType srcType = cast<MemRefType>(getSrc().getType());
601 MemRefType dstType = cast<MemRefType>(getDst().getType());
602
603 if (!dstType.areTrailingDimsContiguous(1))
604 return emitOpError("destination type innermost dim must be contiguous");
605
606 auto elemType = srcType.getElementType();
607 // Check $src and $dst element types are the same.
608 if (elemType != dstType.getElementType())
609 return emitOpError("source and destination element types must match");
610
611 // The copied type's size must be 1, 2, 4, 12 or 16 bytes.
612 auto transferType = getTransferType();
613 int transferSize;
614 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
615 transferSize = vectorTransfer.getNumElements() *
616 vectorTransfer.getElementTypeBitWidth();
617 } else {
618 transferSize = transferType.getIntOrFloatBitWidth();
619 }
620 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
621 return emitOpError(
622 "Transfering type size must be 8, 16, 32, 96 or 128 bits");
623
624 if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
625 !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
626 return emitOpError(
627 "source memory address space must be global or fat raw buffer");
628
629 if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
630 return emitOpError("destination memory address space must be Workgroup");
631
632 return success();
633}
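// Examples of transfer types against the size check above (illustrative):
// f32 (32 bits), vector<6xf16> (96 bits) and vector<4xf32> (128 bits) are
// accepted, while a 64-bit transfer such as vector<2xf32> is rejected because
// 64 is not in the accepted set.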
634
635namespace {
636/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
637/// information or changes layout, the cast can be skipped.
638struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
639 using OpRewritePattern::OpRewritePattern;
640
641 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
642 PatternRewriter &rewriter) const override {
643 bool modified = false;
644 auto foldCast = [&](OpOperand &operand) {
645 if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
646 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
647 rewriter.modifyOpInPlace(gatherOp,
648 [&] { operand.assign(castOp.getSource()); });
649 modified = true;
650 }
651 }
652 };
653
654 foldCast(gatherOp.getSrcMutable());
655 foldCast(gatherOp.getDstMutable());
656
657 return success(modified);
658 }
659};
660} // namespace
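// Illustrative fold (assumed IR): if the source operand was produced by
//   %cast = memref.cast %buf : memref<64xf32> to memref<?xf32>
// the pattern rewrites the gather to use %buf directly, since the cast only
// erases static shape information and can be folded into the consumer.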
661
662void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
663 MLIRContext *context) {
664 results.add<FoldGatherToLDSOfCast>(context);
665}
666
667//===----------------------------------------------------------------------===//
668// TransposeLoadOp
669//===----------------------------------------------------------------------===//
670
671LogicalResult TransposeLoadOp::verify() {
672 MemRefType srcType = cast<MemRefType>(getSrc().getType());
673
674 if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
675 return emitOpError("source memory address space must be Workgroup");
676
677 auto transferType = cast<VectorType>(getType());
678 size_t numElements = transferType.getNumElements();
679 size_t elementTypeSize =
680 transferType.getElementType().getIntOrFloatBitWidth();
681
682 // ElementSize -> NumElements
683 const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
684 {4, 16},
685 {6, 16},
686 {8, 8},
687 {16, 4},
688 };
689
690 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
691 if (validNumElems == kValidLoadSizeMap.end()) {
692 return emitOpError("Unsupported element type size for transpose load: ")
693 << elementTypeSize << " bits";
694 }
695 if (numElements != validNumElems->second) {
696 return emitOpError(
697 "Transferring type size mismatch: expected num of elements: ")
698 << validNumElems->second;
699 }
700
701 return success();
702}
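// Note on the table above: each accepted combination loads 64 bits
// (4 x 16, 8 x 8, 16 x 4), except the 6-bit element case, which loads
// 6 x 16 = 96 bits.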
703
704//===----------------------------------------------------------------------===//
705// ScaledMFMAOp
706//===----------------------------------------------------------------------===//
707
708namespace {
709/// If a scale operand is a single scale extracted from a larger vector,
710/// repack it as a 4-element chunk and select the scale via the opsel index.
711struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
712 using OpRewritePattern::OpRewritePattern;
713
714 LogicalResult matchAndRewrite(ScaledMFMAOp op,
715 PatternRewriter &rewriter) const override {
716 Location loc = op.getLoc();
717 auto setOpsel = [&op](unsigned idx, int64_t val) {
718 switch (idx) {
719 case 3:
720 op.setScalesIdxA(val);
721 break;
722 case 4:
723 op.setScalesIdxB(val);
724 break;
725 default:
726 break;
727 }
728 };
729
730 // For every scale operand of this ScaledMFMAOp, if the scale is produced by
731 // the extraction of a single scale from some vector, then attempt to
732 // extract 4 values from that vector instead.
733 //
734 // Example: (f8 here means f8E8M0FNU)
735 // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
736 // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
737 // amdgpu.scaled_mfma(%scale[0] * ...
738 //
739 // rewrite to:
740 //
741 // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
742 // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
743 // amdgpu.scaled_mfma(%scale[0-3] * ...
744 //
745 // This creates duplicate shape_casts for every use but these will be
746 // removed in CSE.
747 for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
748 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
749 if (!insertOp) {
750 return rewriter.notifyMatchFailure(op,
751 "defining op not a vector.insert");
752 }
753 // If the inserted value is not a single scalar, the scales are already packed.
754 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
755 return rewriter.notifyMatchFailure(
756 op, "scaled mfma operand already packed");
757 }
758
759 auto extractOp =
760 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
761 if (!extractOp) {
762 return rewriter.notifyMatchFailure(op,
763 "defining op not a vector.extract");
764 }
765
766 Value scaleSrc = extractOp.getOperand(0);
767 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
768 if (!scaleSrcType) {
769 return rewriter.notifyMatchFailure(op, "not a vector type");
770 }
771
772 // We do not handle dynamic dims yet; assume that the input has been
773 // padded to a static shape.
774 if (!scaleSrcType.hasStaticShape()) {
775 return rewriter.notifyMatchFailure(op,
776 "dynamic dims not yet supported");
777 }
778
779 int64_t numElements = scaleSrcType.getNumElements();
780 if (numElements <= 4) {
781 return rewriter.notifyMatchFailure(
782 op, "no packing if # of scales less than four");
783 }
784
785 // Find a linearized idx using the size and offsets of the extract op.
786 auto extractedPos = llvm::to_vector_of<int64_t>(
787 llvm::reverse(extractOp.getStaticPosition()));
788 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
789 int64_t scaleSrcRank = scaleSrcType.getRank();
790 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
791 for (int64_t i = 1; i < scaleSrcRank; ++i) {
792 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
793 }
794 int64_t idx = linearize(extractedPos, extractSizes);
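// Worked example (hypothetical shapes): for a scale source of type
// vector<2x8xf8E8M0FNU> and an extract at position [1, 3], extractSizes is
// [1, 8] and the reversed position [3, 1] linearizes to
// idx = 3 * 1 + 1 * 8 = 11, the row-major index of element (1, 3).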
795
796 // All n scales (where n is the total number of scales) are now
797 // extracted in chunks of 4 elements. This is done by dividing the
798 // original vector of scales into groups of 4 elements
799 // at offsets 0, 4, ..., n - 4. Every extraction of a single
800 // scale at a particular index is replaced with an extraction
801 // of the entire group of 4 elements to which that index belongs.
802 //
803 // If the number of scales is not divisible by 4, the remaining
804 // n mod 4 scales are covered by a chunk of 4 elements starting at
805 // offset n - 4.
806 int64_t offset = idx - (idx % 4);
807 int64_t opsel = idx - offset;
808 int64_t size = 4l;
809 // Accommodate remaining elements in the case of non-4-divisible vectors.
810 if (numElements - offset < size) {
811 opsel = size - (numElements - idx);
812 offset = numElements - 4l;
813 }
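// Continuing the example (hypothetical values): idx = 11 with
// numElements = 16 gives offset = 8 and opsel = 3, i.e. the scale is entry 3
// of the 4-element chunk starting at 8. For numElements = 10 and idx = 9, the
// last chunk would run off the end (10 - 8 < 4), so offset becomes
// 10 - 4 = 6 and opsel becomes 4 - (10 - 9) = 3, which still addresses
// element 9.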
814 Type scaleSrcElemType = scaleSrcType.getElementType();
815 auto newSrcType =
816 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
817 Value newScaleSrc =
818 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
819 auto extract = vector::ExtractStridedSliceOp::create(
820 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
821 ArrayRef{int64_t(1)});
822 rewriter.modifyOpInPlace(op, [&] {
823 op->setOperand(opIdx, extract);
824 setOpsel(opIdx, opsel);
825 });
826 }
827 return success();
828 }
829};
830} // namespace
831
832void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
833 MLIRContext *context) {
834 results.add<PackScales>(context);
835}
836
837#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
838
839#define GET_ATTRDEF_CLASSES
840#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
841
842#define GET_OP_CLASSES
843#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"