//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMDGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <optional>

using namespace mlir;
using namespace mlir::amdgpu;

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

namespace {
struct AMDGPUInlinerInterface final : DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
  addInterfaces<AMDGPUInlinerInterface>();
}

//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
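// Note: the optional `existing` operand supplies the packed lanes that the
// truncation does not overwrite, which is why it must already have the packed
// result type checked below.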
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// mxfp float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// FatRawBufferCastOp
//===----------------------------------------------------------------------===//

/// Convert the type `source` to one with the same sizes and strides - and the
/// same offset, unless `resetOffset` is true, in which case the offset is
/// reset to 0. If the offset should be reset but the layout of `source` is
/// neither the identity layout nor a strided layout, this function fails.
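///
/// For example (illustrative), with `resetOffset` set,
///   memref<8x8xf32, strided<[8, 1], offset: 16>>
/// maps to
///   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
/// since strides [8, 1] with offset 0 form the identity layout for an 8x8
/// shape.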
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // Special case: if resetting the offset causes the strided layout to
    // become the identity layout, then reset to the identity layout.
    // TODO: this'll get a lot simpler when we have the contiguous layout.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}

LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}

LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}

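// The integer memory spaces accepted below follow the AMDGPU address-space
// numbering: 0 is generic/flat, 1 is global, 3 is workgroup (LDS), and 7 is
// the buffer fat pointer used for fat raw buffers.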
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}

//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());

  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}

LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}

static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}

template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflow means don't drop
    return false;
  return result >= bufferType.getNumElements();
}
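
// Bounds-checked raw buffer accesses that are statically out of bounds can be
// folded away: the hardware returns 0 for out-of-bounds reads and ignores
// out-of-bounds writes, which is what the two patterns below rely on. For
// example (illustrative), a load from memref<64xf32> at constant index 70
// linearizes to offset 70 >= 64 elements and is replaced by a zero constant.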

namespace {
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};

template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();

    rw.eraseOp(op);
    return success();
  }
};
} // end namespace

void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// ScaledExtPacked816Op
//===----------------------------------------------------------------------===//
LogicalResult ScaledExtPacked816Op::verify() {
  int blockSize = getBlockSize();
  assert((blockSize == 16 || blockSize == 32) && "invalid block size");
  int firstScaleByte = getFirstScaleByte();
  if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
    return emitOpError(
        "blockSize of 16 can only have firstScaleByte be 0 or 1.");
  }
  if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
    return emitOpError(
        "blockSize of 32 can only have firstScaleByte be 0 or 2.");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//

/// Parser for the `custom<MNKDimensionList>` assembly format used by WMMAOp.
ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
                                                IntegerAttr &m, IntegerAttr &n,
                                                IntegerAttr &k) {
  SmallVector<int64_t, 3> dimensions;
  if (parser.parseDimensionList(dimensions, false, false))
    return failure();
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";

  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
  return success();
}
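
// For example, `16x16x32` in the assembly format parses to m = 16, n = 16,
// and k = 32; anything other than exactly three static dimensions is rejected.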

LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;
  }

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
    return emitOpError(
               "source element types must match (except for fp8/bf8) but have ")
           << sourceAType << " and " << sourceBType;
  }

  if (isSrcFloat) {
    if (getClamp())
      return emitOpError("clamp flag is not supported for float types");
    if (getUnsignedA() || getUnsignedB())
      return emitOpError("unsigned flags are not supported for float types");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
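// The element-count checks below reflect how MFMA operands are distributed
// across a 64-lane wave: the wave jointly holds M*K*blocks source elements and
// M*N*blocks accumulator elements. For example (illustrative), a single-block
// 16x16x4 f32 MFMA expects 16*4/64 = 1 source element per lane and
// 16*16/64 = 4 accumulator elements per lane (vector<4xf32>).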
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}

//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
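// For quad_perm, each lane in a group of four reads the value of the lane
// named by its entry; e.g. (illustrative) quad_perm = [1, 0, 3, 2] swaps
// neighboring lane pairs. The row_shl/row_shr/row_ror variants instead take a
// shift or rotate amount between 1 and 15.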
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {

  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
  } break;

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
  } break;

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// PermlaneSwapOp
//===----------------------------------------------------------------------===//
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");

  return success();
}

//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//

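// gather_to_lds copies one element or one small vector per lane from global
// memory (or a fat raw buffer) into LDS; e.g. (illustrative) a transfer type
// of vector<8xf16> moves 128 bits per lane, the largest size accepted below.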
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type innermost dim must be contiguous");

  auto elemType = srcType.getElementType();
  // Check $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // Copy sizes must be 1, 2, 4, 12 or 16 bytes.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}

namespace {
/// If the source/target of a GatherToLDSOp is a CastOp that only removes
/// static information or changes layout, the cast can be skipped.
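///
/// For example (illustrative), when the source comes from
///   %c = memref.cast %src : memref<64xf16> to memref<?xf16>
/// the gather can read from %src directly, since the cast only discards static
/// size information.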
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(gatherOp,
                                   [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
} // namespace

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}

//===----------------------------------------------------------------------===//
// TransposeLoadOp
//===----------------------------------------------------------------------===//

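// The accepted shapes below all transfer 64 bits per lane (16 x 4-bit,
// 8 x 8-bit, or 4 x 16-bit elements), except for 6-bit elements, where the
// 16-element transfer occupies 96 bits.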
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // ElementSize -> NumElements
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected number of elements: ")
           << validNumElems->second;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//

namespace {
/// Check whether the scales input is used by other scaled MFMAs; if it is not
/// used elsewhere, pack the scales.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern<ScaledMFMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // For every scale operand of this ScaledMFMAOp, if the scale is produced
    // by the extraction of a single scale from some vector, then attempt to
    // extract 4 values from that vector instead.
    //
    // Example: (f8 here means f8E8M0FNU)
    // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
    // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
    // amdgpu.scaled_mfma(%scale[0] * ...
    //
    // rewrite to:
    //
    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
    // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
    // amdgpu.scaled_mfma(%scale[0-3] * ...
    //
    // This creates duplicate shape_casts for every use but these will be
    // removed in CSE.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      }
      // If the inserted value is not a single scalar, then it has been packed.
      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");
      }

      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");
      }

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType) {
        return rewriter.notifyMatchFailure(op, "not a vector type");
      }

      // We do not handle dynamic dims yet; assume that the input is padded to
      // a static shape for now.
      if (!scaleSrcType.hasStaticShape()) {
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");
      }

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4) {
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");
      }

      // Find a linearized idx using the size and offsets of the extract op.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // All n scales (where n is the total number of scales) must now be
      // extracted in chunks of 4 elements. This is done by dividing the
      // original vector of scales into groups of 4 elements at offsets
      // 0, 4, ..., n - 4. All extractions of a scale at a particular index
      // are now replaced with an extraction of the entire group of 4 elements
      // to which that index belongs.
      //
      // If the number of scales is not divisible by 4, the trailing scales
      // are instead covered by a final chunk of 4 elements starting at
      // offset n - 4.
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      int64_t size = 4l;
      // Accommodate remaining elements in the case of non-4-divisible vectors.
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType =
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
} // namespace

void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"