1//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14
22#include "mlir/IR/Builders.h"
24#include "mlir/IR/Diagnostics.h"
26#include "mlir/IR/Matchers.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34
35#include <algorithm>
36#include <cstdint>
37#include <limits>
38#include <optional>
39
40using namespace mlir;
41using namespace mlir::amdgpu;
42
43#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
44
45namespace {
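/// Dialect inliner interface for AMDGPU: every op in this dialect is legal to
/// inline into any region.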
46struct AMDGPUInlinerInterface final : DialectInlinerInterface {
47 using DialectInlinerInterface::DialectInlinerInterface;
48 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
49 return true;
50 }
51};
52} // namespace
53
54void AMDGPUDialect::initialize() {
55 addOperations<
56#define GET_OP_LIST
57#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
58 >();
59 addTypes<
60#define GET_TYPEDEF_LIST
61#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
62 >();
63 addAttributes<
64#define GET_ATTRDEF_LIST
65#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
66 >();
67 addInterfaces<AMDGPUInlinerInterface>();
68}
69
70//===----------------------------------------------------------------------===//
71// 8-bit float ops
72//===----------------------------------------------------------------------===//
73LogicalResult PackedTrunc2xFp8Op::verify() {
74 if (getExisting() && getExisting().getType() != getResult().getType())
75 return emitOpError("existing values must have same type as result");
76 return success();
77}
78
79LogicalResult PackedStochRoundFp8Op::verify() {
80 if (getExisting() && getExisting().getType() != getResult().getType())
81 return emitOpError("existing values must have same type as result");
82 return success();
83}
84
85//===----------------------------------------------------------------------===//
86// mxfp float ops
87//===----------------------------------------------------------------------===//
88LogicalResult PackedScaledTruncOp::verify() {
89 if (getExisting() && getExisting().getType() != getResult().getType())
90 return emitOpError("existing values must have same type as result");
91 return success();
92}
93
94//===----------------------------------------------------------------------===//
95// FatRawBufferCastOp
96//===----------------------------------------------------------------------===//
97
98/// Convert the type `source` to one with the same sizes and strides - and
99/// offset, unless `resetOffset` is true, in which case the offset is reset to
100/// 0. If the offset should be reset but the layout of `source` is neither the
101/// identity layout nor a strided layout, this function fails.
102static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
103 bool resetOffset) {
104 MLIRContext *ctx = source.getContext();
105 MemRefType::Builder mb(source);
106 mb.setMemorySpace(
107 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
108 MemRefLayoutAttrInterface layout = source.getLayout();
109 if (resetOffset && !layout.isIdentity()) {
110 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
111 if (!stridedLayout)
112 return failure();
113 MemRefLayoutAttrInterface newLayout =
114 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
115 // Special case: if resetting the offset causes the strided layout to become
116 // the identity layout, then reset to the identity layout.
117 // TODO: this'll get a lot simpler when we have the contiguous layout.
118 SmallVector<int64_t> stridesIfIdentity;
119 if (source.hasStaticShape()) {
120 stridesIfIdentity = computeSuffixProduct(source.getShape());
121 } else if (source.getRank() <= 1) {
122 stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
123 }
124 if (stridesIfIdentity == stridedLayout.getStrides()) {
125 newLayout = AffineMapAttr::get(
126 AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
127 }
128 mb.setLayout(newLayout);
129 }
130 return (MemRefType)(mb);
131}
132
133LogicalResult FatRawBufferCastOp::inferReturnTypes(
134 MLIRContext *context, std::optional<Location> location, ValueRange operands,
135 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
136 SmallVectorImpl<Type> &inferredReturnTypes) {
137 Adaptor adaptor(operands, attributes, properties, regions);
138 auto sourceType =
139 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
140 if (!sourceType)
141 return failure();
142 FailureOr<MemRefType> resultType =
143 getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
144 if (failed(resultType))
145 return failure();
146 inferredReturnTypes = SmallVector<Type>{*resultType};
147 return success();
148}
149
150FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
151 int resultIndex,
152 int dim) {
153 assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
154 Value source = getSource();
155 auto sourceType = cast<MemRefType>(source.getType());
156 if (sourceType.isDynamicDim(dim))
157 return OpFoldResult(
158 builder.createOrFold<memref::DimOp>(getLoc(), source, dim));
159 return OpFoldResult(builder.getIndexAttr(sourceType.getDimSize(dim)));
160}
161
162LogicalResult FatRawBufferCastOp::verify() {
163 FailureOr<MemRefType> expectedResultType =
164 getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
165 if (failed(expectedResultType))
166 return emitOpError("source type ")
167 << getSource().getType() << " can't have its offset reset";
168 if (getResult().getType() != *expectedResultType)
169 return emitOpError("expected result type to be ")
170 << *expectedResultType << " but got " << getResult().getType();
171 return success();
172}
173
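/// Returns true if `memorySpace` denotes global memory: no memory space at
/// all, integer address space 0 or 1, or gpu::AddressSpace::Global.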
174static bool hasGlobalMemorySpace(Attribute memorySpace) {
175 if (!memorySpace)
176 return true;
177 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
178 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
179 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
180 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
181 return false;
182}
183
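/// Returns true if `memorySpace` denotes workgroup (LDS) memory: integer
/// address space 3 or gpu::AddressSpace::Workgroup.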
184static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
185 if (!memorySpace)
186 return false;
187 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
188 return intMemorySpace.getInt() == 3;
189 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
190 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
191 return false;
192}
193
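/// Returns true if `memorySpace` denotes the fat raw buffer address space:
/// integer address space 7 or amdgpu::AddressSpace::FatRawBuffer.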
194static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
195 if (!memorySpace)
196 return false;
197 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
198 return intMemorySpace.getInt() == 7;
199 if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
200 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
201 return false;
202}
203
204//===----------------------------------------------------------------------===//
205// RawBuffer*Op
206//===----------------------------------------------------------------------===//
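/// Common verifier for the raw buffer ops: the underlying memref must live in
/// global memory, must be ranked, and must be indexed with exactly one index
/// per dimension.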
207template <typename T>
208static LogicalResult verifyRawBufferOp(T &op) {
209 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
210 bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
211
212 if (!isGlobal)
213 return op.emitOpError(
214 "Buffer ops must operate on a memref in global memory");
215 if (!bufferType.hasRank())
216 return op.emitOpError(
217 "Cannot meaningfully buffer_store to an unranked memref");
218 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
219 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
220 " indices to memref");
221 return success();
222}
223
224LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
225
226LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
227
228LogicalResult RawBufferAtomicFaddOp::verify() {
229 return verifyRawBufferOp(*this);
230}
231
232LogicalResult RawBufferAtomicFmaxOp::verify() {
233 return verifyRawBufferOp(*this);
234}
235
236LogicalResult RawBufferAtomicSmaxOp::verify() {
237 return verifyRawBufferOp(*this);
238}
239
240LogicalResult RawBufferAtomicUminOp::verify() {
241 return verifyRawBufferOp(*this);
242}
243
244LogicalResult RawBufferAtomicCmpswapOp::verify() {
245 return verifyRawBufferOp(*this);
246}
247
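/// Returns the value of `v` if it is a compile-time constant i32, and
/// std::nullopt otherwise.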
248static std::optional<uint32_t> getConstantUint32(Value v) {
249 APInt cst;
250 if (!v.getType().isInteger(32))
251 return std::nullopt;
252 if (matchPattern(v, m_ConstantInt(&cst)))
253 return cst.getZExtValue();
254 return std::nullopt;
255}
256
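/// Returns true if a bounds-checked raw buffer access is statically known to
/// be out of bounds: the buffer shape, all indices, and all offsets must be
/// compile-time constants, and their linearized sum must not overflow 32 bits
/// and must reach past the buffer's element count.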
257template <typename OpType>
258static bool staticallyOutOfBounds(OpType op) {
259 if (!op.getBoundsCheck())
260 return false;
261 MemRefType bufferType = op.getMemref().getType();
262 if (!bufferType.hasStaticShape())
263 return false;
264 int64_t offset;
265 SmallVector<int64_t> strides;
266 if (failed(bufferType.getStridesAndOffset(strides, offset)))
267 return false;
268 int64_t result = offset + op.getIndexOffset().value_or(0);
269 if (op.getSgprOffset()) {
270 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
271 if (!sgprOffset)
272 return false;
273 result += *sgprOffset;
274 }
275 if (strides.size() != op.getIndices().size())
276 return false;
277 int64_t indexVal = 0;
278 for (auto pair : llvm::zip(strides, op.getIndices())) {
279 int64_t stride = std::get<0>(pair);
280 Value idx = std::get<1>(pair);
281 std::optional<uint32_t> idxVal = getConstantUint32(idx);
282 if (!idxVal)
283 return false;
284 indexVal += stride * *idxVal;
285 }
286 result += indexVal;
287 if (result > std::numeric_limits<uint32_t>::max())
288 // Overflow means don't drop
289 return false;
290 return result >= bufferType.getNumElements();
291}
292
293namespace {
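/// Folds raw buffer loads that are statically out of bounds to a zero constant
/// of the loaded type, since bounds-checked loads past the end of the buffer
/// return 0.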
294template <typename OpType>
295struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
296 using OpRewritePattern<OpType>::OpRewritePattern;
297
298 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
299 if (!staticallyOutOfBounds(op))
300 return failure();
301 Type loadType = op.getResult().getType();
302 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
303 rw.getZeroAttr(loadType));
304 return success();
305 }
306};
307
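/// Erases raw buffer stores and atomics that are statically out of bounds,
/// since bounds-checked writes past the end of the buffer are dropped.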
308template <typename OpType>
309struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
310 using OpRewritePattern<OpType>::OpRewritePattern;
311
312 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
313 if (!staticallyOutOfBounds(op))
314 return failure();
315
316 rw.eraseOp(op);
317 return success();
318 }
319};
320} // end namespace
321
322void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
323 MLIRContext *context) {
324 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
325}
326
327void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
328 MLIRContext *context) {
329 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
330}
331
332void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
333 RewritePatternSet &results, MLIRContext *context) {
334 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
335}
336
337void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
338 RewritePatternSet &results, MLIRContext *context) {
339 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
340}
341
342void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
343 RewritePatternSet &results, MLIRContext *context) {
344 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
345}
346
347void RawBufferAtomicUminOp::getCanonicalizationPatterns(
348 RewritePatternSet &results, MLIRContext *context) {
349 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
350}
351
352void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
353 RewritePatternSet &results, MLIRContext *context) {
354 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
355 context);
356}
357
358//===----------------------------------------------------------------------===//
359// ScaledExtPackedMatrixOp
360//===----------------------------------------------------------------------===//
361LogicalResult ScaledExtPackedMatrixOp::verify() {
362 int blockSize = getBlockSize();
363 assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
364
365 int firstScaleByte = getFirstScaleByte();
366 int firstScaleLane = getFirstScaleLane();
367 auto sourceType = cast<VectorType>(getSource().getType());
368 Type elementType = sourceType.getElementType();
369 auto floatType = cast<FloatType>(elementType);
370 unsigned bitWidth = floatType.getWidth();
371
372 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
373
374 const bool is_fp8 = bitWidth == 8;
375 const bool is_block_16 = blockSize == 16;
376
377 if (!is_fp8) {
378 if (is_block_16) {
379 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
380 return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
381 "or 1 for f4 and f6.");
382 }
383 } else {
384 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
385 return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
386 "or 2 for f4 and f6.");
387 }
388 }
389 } else {
390 if (is_block_16) {
391 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
392 ((firstScaleLane == 16) && (firstScaleByte == 2));
393 if (!is_valid) {
394 return emitOpError("blockSize of 16 can only have (firstScaleLane, "
395 "firstScaleByte) be (0, 0) or (16, 2) for f8.");
396 }
397 }
398 }
399
400 return success();
401}
402
403//===----------------------------------------------------------------------===//
404// WMMAOp
405//===----------------------------------------------------------------------===//
406
407ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
408 IntegerAttr &m, IntegerAttr &n,
409 IntegerAttr &k) {
410 SmallVector<int64_t, 3> dimensions;
411 if (parser.parseDimensionList(dimensions, false, false))
412 return failure();
413 if (dimensions.size() != 3)
414 return parser.emitError(parser.getCurrentLocation())
415 << "expected 3 dimensions in MNK dimension list";
416
417 m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
418 n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
419 k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
420 return success();
421}
422
423LogicalResult WMMAOp::verify() {
424 auto sourceAType = cast<VectorType>(getSourceA().getType());
425 auto sourceBType = cast<VectorType>(getSourceB().getType());
426 auto destType = cast<VectorType>(getDestC().getType());
427
428 Type sourceAElemType = sourceAType.getElementType();
429 Type sourceBElemType = sourceBType.getElementType();
430 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
431 return emitOpError("source vectors have different lengths: ")
432 << sourceAType << " vs. " << sourceBType;
433 }
434
435 bool isDestFloat = destType.getElementType().isFloat();
436 bool isSrcFloat = sourceAElemType.isFloat();
437
438 if (isDestFloat && !isSrcFloat)
439 return emitOpError("expected float sources with float destination");
440 if (!isDestFloat && isSrcFloat)
441 return emitOpError("expected int sources with int destination");
442
443 if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
444 return emitOpError(
445 "source element types must match (except for fp8/bf8) but have ")
446 << sourceAType << " and " << sourceBType;
447 }
448
449 if (isSrcFloat) {
450 if (getClamp())
451 return emitOpError("clamp flag is not supported for float types");
452 if (getUnsignedA() || getUnsignedB())
453 return emitOpError("unsigned flags are not supported for float types");
454 }
455 return success();
456}
457
458//===----------------------------------------------------------------------===//
459// MFMAOp
460//===----------------------------------------------------------------------===//
461LogicalResult MFMAOp::verify() {
462 constexpr uint32_t waveSize = 64;
463 Builder b(getContext());
464
465 Type sourceType = getSourceA().getType();
466 Type destType = getDestC().getType();
467
468 Type sourceElem = sourceType, destElem = destType;
469 uint32_t sourceLen = 1, destLen = 1;
470 if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
471 sourceLen = sourceVector.getNumElements();
472 sourceElem = sourceVector.getElementType();
473 }
474 if (auto destVector = dyn_cast<VectorType>(destType)) {
475 destLen = destVector.getNumElements();
476 destElem = destVector.getElementType();
477 }
478
479 Type sourceBType = getSourceB().getType();
480 if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
481 int64_t sourceBLen = 1;
482 Type sourceBElem = sourceBType;
483 if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
484 sourceBLen = sourceBVector.getNumElements();
485 sourceBElem = sourceBVector.getElementType();
486 }
487 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
488 !sourceBElem.isFloat(4))
489 return emitOpError("expected both source operands to have small-float "
490 "elements if one does");
491 if (sourceLen != sourceBLen)
492 return emitOpError(
493 "expected both small-float source vectors to have the same length");
494 } else {
495 if (sourceType != sourceBType)
496 return emitOpError("expected both non-small-float source operand types "
497 "to match exactly");
498 }
499 // i32 and i64 sources are packed vectors of i8; normalize them to i8 lanes for the element-count checks below.
500 if (sourceElem.isInteger(32)) {
501 sourceLen *= 4;
502 sourceElem = b.getI8Type();
503 }
504 if (sourceElem.isInteger(64)) {
505 sourceLen *= 8;
506 sourceElem = b.getI8Type();
507 }
508
509 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
510 if (sourceLen != numSourceElems)
511 return emitOpError("expected " + Twine(numSourceElems) +
512 " source values for this operation but got " +
513 Twine(sourceLen));
514
515 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
516 if (destLen != numDestElems)
517 return emitOpError("expected " + Twine(numDestElems) +
518 " result values for this operation but got " +
519 Twine(destLen));
520
521 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
522 return emitOpError(
523 "double-precision ops do not support permuting lanes of B");
524 if (destElem.isF64() && getCbsz() != 0)
525 return emitOpError(
526 "double-precision ops do not support permuting lanes of A");
527 if (getAbid() >= (1u << getCbsz()))
528 return emitOpError(
529 "block ID for permuting A (abid) must be below 2 ** cbsz");
530
531 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
532 return emitOpError(
533 "negation flags only available for double-precision operations");
534
535 return success();
536}
537
538//===----------------------------------------------------------------------===//
539// DPPOp
540//===----------------------------------------------------------------------===//
541LogicalResult DPPOp::verify() {
542 Type srcType = getSrc().getType();
543 if (srcType.getIntOrFloatBitWidth() > 64) {
544 return emitOpError("integer and floating point types larger than 64 bits "
545 "are not supported");
546 }
547
548 DPPPerm kind = getKind();
549 Attribute permArgument = getPermArgument().value_or(Attribute{});
550
551 switch (kind) {
552
553 case DPPPerm::quad_perm: {
554 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
555 if (!quadPermAttr || quadPermAttr.size() != 4) {
556 return emitOpError("quad_perm attribute must have exactly 4 elements");
557 }
558 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
559 int32_t num = elem.getInt();
560 if (num < 0 || num > 3) {
561 return emitOpError(
562 "Each element of quad_perm must be in the range [0, 3]");
563 }
564 }
565 } break;
566
567 case DPPPerm::row_shl:
568 case DPPPerm::row_shr:
569 case DPPPerm::row_ror: {
570 if (!permArgument) {
571 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
572 "' value not specified");
573 }
574 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
575 uint32_t attrValue = intAttr.getInt();
576 if (attrValue < 1 || attrValue > 15) {
577 return emitOpError("Attribute value must be between 1 and 15");
578 }
579 }
580 } break;
581
582 case DPPPerm::wave_shl:
583 case DPPPerm::wave_shr:
584 case DPPPerm::wave_rol:
585 case DPPPerm::wave_ror:
586 case DPPPerm::row_mirror:
587 case DPPPerm::row_half_mirror:
588 case DPPPerm::row_bcast_15:
589 case DPPPerm::row_bcast_31: {
590 if (permArgument && !isa<UnitAttr>(permArgument)) {
591 return emitOpError("Expected unit attribute for permArgument, but found "
592 "non-trivial argument");
593 }
594 break;
595 }
596 }
597 return success();
598}
599
600//===----------------------------------------------------------------------===//
601// PermlaneSwapOp
602//===----------------------------------------------------------------------===//
603LogicalResult PermlaneSwapOp::verify() {
604 unsigned rowLength = getRowLength();
605
606 if (rowLength != 16 && rowLength != 32)
607 return emitOpError("row_length attribute must either be 16 or 32.");
608
609 return success();
610}
611
612//===----------------------------------------------------------------------===//
613// MemoryCounterWaitOp
614//===----------------------------------------------------------------------===//
615
616namespace {
617/// Fuse adjacent memory counter wait ops, taking the minimum value of the
618/// counters.
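/// For example, a wait with load(2) immediately followed by a wait with
/// load(1) becomes a single wait with load(1).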
619struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
620 using Base::Base;
621
622 LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
623 PatternRewriter &rewriter) const override {
624 auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
625 if (!next)
626 return failure();
627
628 auto setters = {&MemoryCounterWaitOp::setLoad,
629 &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
630 &MemoryCounterWaitOp::setExp,
631 &MemoryCounterWaitOp::setTensor};
632 auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
633 op.getTensor()};
634 auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
635 next.getExp(), next.getTensor()};
636 rewriter.modifyOpInPlace(op, [&] {
637 for (auto [setter, lhs, rhs] :
638 llvm::zip_equal(setters, lhsVals, rhsVals)) {
639 if (lhs && rhs) {
640 (op.*setter)(std::min(*lhs, *rhs));
641 } else if (lhs) {
642 (op.*setter)(*lhs);
643 } else if (rhs) {
644 (op.*setter)(*rhs);
645 }
646 }
647 });
648 rewriter.eraseOp(next);
649 return success();
650 }
651};
652} // namespace
653
654void MemoryCounterWaitOp::getCanonicalizationPatterns(
655 RewritePatternSet &results, MLIRContext *context) {
656 results.add<FuseMemoryCounterWaitOp>(context);
657}
658
659//===----------------------------------------------------------------------===//
660// GatherToLDSOp
661//===----------------------------------------------------------------------===//
662
663LogicalResult GatherToLDSOp::verify() {
664 MemRefType srcType = cast<MemRefType>(getSrc().getType());
665 MemRefType dstType = cast<MemRefType>(getDst().getType());
666
667 if (!dstType.areTrailingDimsContiguous(1))
668 return emitOpError("destination type inner most dim must be contiguous");
669
670 auto elemType = srcType.getElementType();
671 // Check $src and $dst element types are the same.
672 if (elemType != dstType.getElementType())
673 return emitOpError("source and destination element types must match");
674
675 // Copied chunks must be 1, 2, 4, 12, or 16 bytes (8, 16, 32, 96, or 128 bits).
676 auto transferType = getTransferType();
677 int transferSize;
678 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
679 transferSize = vectorTransfer.getNumElements() *
680 vectorTransfer.getElementTypeBitWidth();
681 } else {
682 transferSize = transferType.getIntOrFloatBitWidth();
683 }
684 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
685 return emitOpError(
686 "Transfering type size must be 8, 16, 32, 96 or 128 bits");
687
688 if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
689 !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
690 return emitOpError(
691 "source memory address space must be global or fat raw buffer");
692
693 if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
694 return emitOpError("destination memory address space must be Workgroup");
695
696 return success();
697}
698
699namespace {
700/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
701/// information or changes layout, the cast can be skipped.
702struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
703 using OpRewritePattern::OpRewritePattern;
704
705 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
706 PatternRewriter &rewriter) const override {
707 bool modified = false;
708 auto foldCast = [&](OpOperand &operand) {
709 if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
710 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
711 rewriter.modifyOpInPlace(gatherOp,
712 [&] { operand.assign(castOp.getSource()); });
713 modified = true;
714 }
715 }
716 };
717
718 foldCast(gatherOp.getSrcMutable());
719 foldCast(gatherOp.getDstMutable());
720
721 return success(modified);
722 }
723};
724} // namespace
725
726void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
727 MLIRContext *context) {
728 results.add<FoldGatherToLDSOfCast>(context);
729}
730
731//===----------------------------------------------------------------------===//
732// TransposeLoadOp
733//===----------------------------------------------------------------------===//
734
735LogicalResult TransposeLoadOp::verify() {
736 MemRefType srcType = cast<MemRefType>(getSrc().getType());
737
738 if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
739 return emitOpError("source memory address space must be Workgroup");
740
741 auto transferType = cast<VectorType>(getType());
742 size_t numElements = transferType.getNumElements();
743 size_t elementTypeSize =
744 transferType.getElementType().getIntOrFloatBitWidth();
745
746 // ElementSize -> NumElements
747 const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
748 {4, 16},
749 {6, 16},
750 {8, 8},
751 {16, 4},
752 };
753
754 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
755 if (validNumElems == kValidLoadSizeMap.end())
756 return emitOpError("Unsupported element type size for transpose load: ")
757 << elementTypeSize << " bits";
758
759 if (numElements != validNumElems->second)
760 return emitOpError(
761 "Transferring type size mismatch: expected num of elements: ")
762 << validNumElems->second;
763
764 return success();
765}
766
767//===----------------------------------------------------------------------===//
768// MakeDmaBaseOp
769//===----------------------------------------------------------------------===//
770
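/// Shared verifier for the DMA base ops: the LDS memref must live in
/// workgroup memory, the global memref in global memory, and the element type
/// must be 1, 2, 4, or 8 bytes wide.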
771template <typename BaseOp>
772static LogicalResult verifyBase(BaseOp op) {
773 auto ldsType = cast<MemRefType>(op.getLds().getType());
774 auto globalType = cast<MemRefType>(op.getGlobal().getType());
775 if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
776 return op.emitOpError(
777 "lds memref must have workgroup address space attribute.");
778 if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
779 return op.emitOpError(
780 "global memref must have global address space attribute.");
781
782 Type elementType = ldsType.getElementType();
783 unsigned width = elementType.getIntOrFloatBitWidth();
784
785 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
786 return op.emitOpError(
787 "element type must be 1, 2, 4, or 8 bytes long but type was ")
788 << width << " bits long.";
789 return success();
790}
791
792LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
793
794//===----------------------------------------------------------------------===//
795// MakeGatherDmaBaseOp
796//===----------------------------------------------------------------------===//
797
798LogicalResult
799TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
800 Type elementType, Type indexType) {
801 unsigned width = elementType.getIntOrFloatBitWidth();
802 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
803 return emitError()
804 << "element type must be 1, 2, 4, or 8 bytes wide but type "
805 << elementType << " is " << width / 8 << " bytes wide.";
806 MLIRContext *ctx = elementType.getContext();
807 Type i16 = IntegerType::get(ctx, 16);
808 Type i32 = IntegerType::get(ctx, 32);
809 if (!llvm::is_contained({i16, i32}, indexType))
810 return emitError() << "index type must be i16 or i32 but index type is "
811 << indexType << ".";
812 return success();
813}
814
815LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
816
817//===----------------------------------------------------------------------===//
818// MakeDmaDescriptorOp
819//===----------------------------------------------------------------------===//
820
821LogicalResult MakeDmaDescriptorOp::verify() {
822 ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides();
823
824 if (globalStaticStrides.empty())
825 return emitOpError("strides must not be empty.");
826 if (globalStaticStrides.back() != 1)
827 return emitOpError("strides for the innermost dimension must be 1.");
828
829 ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
830 size_t rank = globalStaticSizes.size();
831 if (rank > 5)
832 return emitOpError("tensor and tile must be at most of rank 5.");
833 if (rank != globalStaticStrides.size())
834 return emitOpError("strides and sizes must have same rank.");
835
836 ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes();
837 if (rank != sharedStaticSizes.size())
838 return emitOpError("tensor must have same rank as tile.");
839
840 unsigned elementTypeWidth = getElementTypeWidth();
841 if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
842 return emitOpError(
843 "element type width must be 1, 2, 4 or 8 bytes, but was ")
844 << elementTypeWidth << " bits long";
845
846 if (Value atomicBarrierAddress = getAtomicBarrierAddress()) {
847 auto atomicBarrierAddressType =
848 cast<MemRefType>(atomicBarrierAddress.getType());
849 bool barrierInLDS =
850 hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
851 if (!barrierInLDS)
852 return emitOpError("atomic barrier address must be in LDS.");
853 }
854
855 if (getEarlyTimeout() && !getWorkgroupMask())
856 return emitOpError(
857 "early timeout does not apply when workgroup_mask is not set.");
858 return success();
859}
860
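/// Folds any constant values in the descriptor's dynamic size and stride lists
/// into the corresponding static lists, returning the op's own result when
/// something was folded and nullptr otherwise.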
861OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
862 SmallVector<OpFoldResult> mixedGlobalSizes(getMixedGlobalSizes());
863 SmallVector<OpFoldResult> mixedGlobalStrides(getMixedGlobalStrides());
864 SmallVector<OpFoldResult> mixedSharedSizes(getMixedSharedSizes());
865
866 if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
867 /*onlyNonZero=*/true)) &&
868 failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
869 /*onlyNonZero=*/true)) &&
870 failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
871 /*onlyNonZero=*/true)))
872 return nullptr;
873
874 SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
875 dynamicSharedSizes;
876 SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
877 staticSharedSizes;
878
879 dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
880 staticGlobalSizes);
881 setGlobalStaticSizes(staticGlobalSizes);
882 getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
883
884 dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
885 staticGlobalStrides);
886 setGlobalStaticStrides(staticGlobalStrides);
887 getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
888
889 dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
890 staticSharedSizes);
891 setSharedStaticSizes(staticSharedSizes);
892 getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
893 return getResult();
894}
895
896//===----------------------------------------------------------------------===//
897// ScaledMFMAOp
898//===----------------------------------------------------------------------===//
899
900namespace {
901/// If a scale operand of this op is a single scalar extracted from a larger vector
902/// of scales, extract a packed group of four scales instead and select the scale via scalesIdx.
903struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
904 using OpRewritePattern::OpRewritePattern;
905
906 LogicalResult matchAndRewrite(ScaledMFMAOp op,
907 PatternRewriter &rewriter) const override {
908 Location loc = op.getLoc();
909 auto setOpsel = [&op](unsigned idx, int64_t val) {
910 switch (idx) {
911 case 3:
912 op.setScalesIdxA(val);
913 break;
914 case 4:
915 op.setScalesIdxB(val);
916 break;
917 default:
918 break;
919 }
920 };
921
922 // For every scale operand of this ScaledMFMAOp, if the scale is produced by
923 // the extraction of a single scale from some vector, then attempt to
924 // extract 4 values from that vector instead.
925 //
926 // Example: (f8 here means f8E8M0FNU)
927 // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
928 // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
929 // amdgpu.scaled_mfma(%scale[0] * ...
930 //
931 // rewrite to:
932 //
933 // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
934 // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
935 // amdgpu.scaled_mfma(%scale[0-3] * ...
936 //
937 // This creates duplicate shape_casts for every use but these will be
938 // removed in CSE.
939 for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
940 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
941 if (!insertOp) {
942 return rewriter.notifyMatchFailure(op,
943 "defining op not a vector.insert");
944 }
945 // If the inserted value is not a single scalar, the scales are already packed.
946 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
947 return rewriter.notifyMatchFailure(
948 op, "scaled mfma operand already packed");
949 }
950
951 auto extractOp =
952 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
953 if (!extractOp) {
954 return rewriter.notifyMatchFailure(op,
955 "defining op not a vector.extract");
956 }
957
958 Value scaleSrc = extractOp.getOperand(0);
959 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
960 if (!scaleSrcType) {
961 return rewriter.notifyMatchFailure(op, "not a vector type");
962 }
963
964 // We do not handle dynamic dims yet, assume that the input is padded to
965 // a static shape now.
966 if (!scaleSrcType.hasStaticShape()) {
967 return rewriter.notifyMatchFailure(op,
968 "dynamic dims not yet supported");
969 }
970
971 int64_t numElements = scaleSrcType.getNumElements();
972 if (numElements <= 4) {
973 return rewriter.notifyMatchFailure(
974 op, "no packing if # of scales less than four");
975 }
976
977 // Find a linearized idx using the size and offsets of the extract op.
978 auto extractedPos = llvm::to_vector_of<int64_t>(
979 llvm::reverse(extractOp.getStaticPosition()));
980 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
981 int64_t scaleSrcRank = scaleSrcType.getRank();
982 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
983 for (int64_t i = 1; i < scaleSrcRank; ++i) {
984 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
985 }
986 int64_t idx = linearize(extractedPos, extractSizes);
987
988 // All n scales (where n is the total number of scales) must now be
989 // extracted in chunks of 4 elements. This is done by dividing the
990 // original vector of scales into groups of 4 elements starting at
991 // offsets 0, 4, ..., n - 4. Every extraction of a scale at a particular
992 // index is replaced with an extraction of the whole group of 4 elements
993 // containing that index, and the opsel value selects the scale within
994 // its group.
995 //
996 // If the number of scales is not divisible by 4, the trailing scales are
997 // extracted as a chunk of 4 elements starting at offset n - 4, which
998 // overlaps the previous chunk.
998 int64_t offset = idx - (idx % 4);
999 int64_t opsel = idx - offset;
1000 int64_t size = 4l;
1001 // Accommodate remaining elements in the case of non-4-divisible vectors.
1002 if (numElements - offset < size) {
1003 opsel = size - (numElements - idx);
1004 offset = numElements - 4l;
1005 }
1006 Type scaleSrcElemType = scaleSrcType.getElementType();
1007 auto newSrcType =
1008 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
1009 Value newScaleSrc =
1010 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
1011 auto extract = vector::ExtractStridedSliceOp::create(
1012 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
1013 ArrayRef{int64_t(1)});
1014 rewriter.modifyOpInPlace(op, [&] {
1015 op->setOperand(opIdx, extract);
1016 setOpsel(opIdx, opsel);
1017 });
1018 }
1019 return success();
1020 }
1021};
1022} // namespace
1023
1024void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
1025 MLIRContext *context) {
1026 results.add<PackScales>(context);
1027}
1028
1029#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
1030
1031#define GET_ATTRDEF_CLASSES
1032#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
1033
1034#define GET_TYPEDEF_CLASSES
1035#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
1036
1037#define GET_OP_CLASSES
1038#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"