//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMDGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <optional>

using namespace mlir;
using namespace mlir::amdgpu;

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// mxfp float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// FatRawBufferCastOp
//===----------------------------------------------------------------------===//

/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `resetOffset` is true, in which case the offset is reset to
/// 0. If the offset should be reset but the layout of `source` is neither the
/// identity layout nor a strided layout, this function fails.
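///
/// Worked example (derived from the logic below, not part of the original
/// comment): resetting the offset of
/// `memref<8x8xf32, strided<[8, 1], offset: ?>>` first yields the layout
/// `strided<[8, 1]>`; for the static 8x8 shape those strides equal the suffix
/// product of the shape, so the layout collapses to the identity and the
/// result is `memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>`.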
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // Special case: if resetting the offset causes the strided layout to
    // become the identity layout, then use the identity layout instead.
    // TODO: this'll get a lot simpler when we have the contiguous layout.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}

LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}

LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}

static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}

//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());

  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}

LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}

static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}

template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflow means don't drop
    return false;
  return result >= bufferType.getNumElements();
}
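// Worked example (illustrative, derived from the check above): a
// `memref<16x16xf32>` buffer has 256 elements, strides [16, 1], and offset 0.
// A raw buffer access with boundsCheck enabled and constant i32 indices
// [20, 0] linearizes to 20 * 16 + 0 * 1 = 320 >= 256, so it is statically out
// of bounds and the patterns below fold it away (loads become a zero
// constant, writes are erased).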

namespace {
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};

template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();

    rw.eraseOp(op);
    return success();
  }
};
} // end namespace

void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type sourceBType = getSourceB().getType();
  Type destType = getDestC().getType();

  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
  VectorType destVectorType = dyn_cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type sourceBElemType = sourceVectorBType.getElementType();
  Type destElemType = destVectorType.getElementType();

  if (sourceVectorAType.getNumElements() !=
      sourceVectorBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceVectorAType << " vs. " << sourceVectorBType;
  }

  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }

  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  if (sourceAElemType != sourceBElemType &&
      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
    return emitOpError(
               "source element types must match (except for fp8) but have ")
           << sourceAType << " and " << sourceBType;
  }
  return success();
}
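// Illustrative operand types that pass this verifier (an assumed example, not
// taken from the original source): sourceA and sourceB of `vector<16xf16>`
// with a `vector<8xf32>` accumulator - the source lengths match and float
// sources pair with a float destination. By contrast, i8 sources with an f32
// destination would be rejected.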

//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
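// Worked size example (derived from the formulas above): for a 32x32x8
// single-block f16 MFMA, numSourceElems = (32 * 8 * 1) / 64 = 4, so each
// source must provide 4 elements (e.g. vector<4xf16>), and
// numDestElems = (32 * 32 * 1) / 64 = 16, so the accumulator/destination must
// provide 16 elements (e.g. vector<16xf32>).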

//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {

  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
  } break;

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
  } break;

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// PermlaneSwapOp
//===----------------------------------------------------------------------===//
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");

  return success();
}

//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//

LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type innermost dim must be contiguous");

  auto elemType = srcType.getElementType();
  // Check $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // Copy type sizes should be 1, 2, 4, 12 or 16 bytes.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
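// Illustrative transfer sizes (derived from the check above): a scalar f32 or
// a vector<8xf16> transfer (32 and 128 bits) is accepted, while a
// vector<2xf32> transfer (64 bits) is rejected because 64 is not one of the
// supported widths {8, 16, 32, 96, 128}.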

namespace {
/// If the source/target of a GatherToLDSOp is a CastOp that only removes
/// static information or changes layout, the cast can be skipped.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(gatherOp,
                                   [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
} // namespace

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}

//===----------------------------------------------------------------------===//
// TransposeLoadOp
//===----------------------------------------------------------------------===//

LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // ElementSize -> NumElements
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }

  return success();
}
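// Illustrative result types that satisfy the table above (element types are
// assumed for the example): 16-bit elements load as vector<4xf16>, 8-bit
// elements as vector<8xi8>, and 4-bit elements as vector<16xi4>, each a
// 64-bit transfer; the 6-bit case is the exception at 16 elements (96 bits).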

//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//

namespace {
/// Check whether the scales input is used by other scaled MFMAs; if it is not
/// reused elsewhere, pack the scales.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // For every scale operand of this ScaledMFMAOp, if the scale is produced
    // by the extraction of a single scale from some vector, then attempt to
    // extract 4 values from that vector instead.
    //
    // Example: (f8 here means f8E8M0FNU)
    // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
    // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
    // amdgpu.scaled_mfma(%scale[0] * ...
    //
    // rewrite to:
    //
    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
    // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
    // amdgpu.scaled_mfma(%scale[0-3] * ...
    //
    // This creates duplicate shape_casts for every use but these will be
    // removed in CSE.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      }
      // If the value being inserted is not a single scalar, then the scales
      // have already been packed.
      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");
      }

      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");
      }

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType) {
        return rewriter.notifyMatchFailure(op, "not a vector type");
      }

      // We do not handle dynamic dims yet; assume that the input is padded to
      // a static shape for now.
      if (!scaleSrcType.hasStaticShape()) {
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");
      }

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4) {
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");
      }

      // Find a linearized idx using the size and offsets of the extract op.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // All n scales (where n is the total number of scales) must now be
      // extracted in chunks of 4 elements. This is done by dividing the
      // original vector of scales into groups of 4 elements at offsets
      // 0, 4, ..., n - 4. All extractions of a scale at a particular index
      // are now replaced with an extraction of the entire group of 4 elements
      // to which that index belongs.
      //
      // If the number of scales happens to be indivisible by 4, extract the
      // remaining scales in a chunk of 4 elements starting at offset n - 4.
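      //
      // Worked example (derived from the code below): with n = 6 scales and a
      // scale extracted at linear index idx = 5, the initial chunk offset is
      // 4 with opsel 1; since only 2 elements remain past offset 4, the chunk
      // is shifted back to offset n - 4 = 2 and opsel becomes 4 - (6 - 5) = 3,
      // i.e. elements [2..5] are extracted and the op selects element 3.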
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      int64_t size = 4l;
      // Accommodate remaining elements in the case of non-4-divisible vectors.
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
                                        scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
          ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
} // namespace

void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"