//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMDGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <optional>

using namespace mlir;
using namespace mlir::amdgpu;

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

namespace {
struct AMDGPUInlinerInterface final : DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
  addInterfaces<AMDGPUInlinerInterface>();
}

//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// mxfp float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// FatRawBufferCastOp
//===----------------------------------------------------------------------===//

/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `resetOffset` is true, in which case the offset is reset to
/// 0. If the offset should be reset but the layout of `source` is neither the
/// identity layout nor a strided layout, this function fails.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // Special case: if resetting the offset causes the strided layout to
    // become the identity layout, then reset to the identity layout.
    // TODO: this'll get a lot simpler when we have the contiguous layout.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
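
// Illustrative example (types written schematically, not taken from a test):
// with resetOffset = true,
//   memref<8x8xf32, strided<[8, 1], offset: ?>>
// becomes
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// because the strides [8, 1] equal the suffix product of the shape, so the
// offset-0 strided layout is normalized to the identity layout; with
// resetOffset = false only the memory space changes.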

LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}

LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}

static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}

//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());

  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
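
// Illustrative use of the checks above (IR written schematically): a load
// such as
//   %v = amdgpu.raw_buffer_load %buf[%i, %j] : memref<128x64xf32>, i32, i32 -> f32
// passes because the memref is in the default (global) address space, is
// ranked, and receives exactly rank-many indices; an unranked memref or a
// workgroup-space memref would be rejected.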

LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}

static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}

template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflow means don't drop
    return false;
  return result >= bufferType.getNumElements();
}
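
// Worked example (illustrative): for a memref<4x4xf32> buffer (16 elements,
// strides [4, 1], offset 0) with boundsCheck enabled, constant indices
// [3, 3], and indexOffset = 4, the linearized element index is
// 0 + 4 + 3 * 4 + 3 * 1 = 19 >= 16, so the access is statically out of
// bounds and the patterns below may fold the op away.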

namespace {
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};

template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();

    rw.eraseOp(op);
    return success();
  }
};
} // end namespace

void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type sourceBType = getSourceB().getType();
  Type destType = getDestC().getType();

  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
  VectorType destVectorType = dyn_cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type sourceBElemType = sourceVectorBType.getElementType();
  Type destElemType = destVectorType.getElementType();

  if (sourceVectorAType.getNumElements() !=
      sourceVectorBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceVectorAType << " vs. " << sourceVectorBType;
  }

  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }

  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  if (sourceAElemType != sourceBElemType &&
      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
    return emitOpError(
               "source element types must match (except for fp8) but have ")
           << sourceAType << " and " << sourceBType;
  }
  return success();
}
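
// Illustrative applications of the rules above: vector<16xf16> sources with a
// vector<8xf32> destination are accepted (float sources, float destination,
// equal source lengths); vector<16xi8> sources with a vector<8xf32>
// destination are rejected, as are mixed f16/bf16 source element types,
// while mixed f8E4M3FN/f8E5M2 sources are allowed.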

//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
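
// Worked example (illustrative): for a 32x32x8 f16 MFMA with blocks = 1 and
// the fixed wave size of 64, the verifier expects
//   sourceLen = (m * k * blocks) / 64 = (32 * 8 * 1) / 64 = 4 and
//   destLen   = (m * n * blocks) / 64 = (32 * 32 * 1) / 64 = 16,
// i.e. vector<4xf16> sources and a vector<16xf32> accumulator/destination.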

//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {

  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
  } break;

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
  } break;

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
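
// Illustrative attribute arguments for the cases above: quad_perm takes an
// array such as [0, 1, 2, 3] (exactly four values, each in [0, 3]);
// row_shl/row_shr/row_ror take a single integer in [1, 15]; the wave_* and
// row_mirror/row_half_mirror/row_bcast_* variants take no argument (or a
// unit attribute).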

//===----------------------------------------------------------------------===//
// PermlaneSwapOp
//===----------------------------------------------------------------------===//
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");

  return success();
}

//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//

LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type innermost dim must be contiguous");

  auto elemType = srcType.getElementType();
  // Check $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // Copy type sizes should be 1, 2, 4, 12 or 16 bytes.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
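
// Illustrative transfer types for the size check above: f32 (32 bits),
// vector<4xi8> (32 bits), and vector<8xf16> (128 bits) are accepted, while
// vector<2xf32> (64 bits) is rejected because 64 is not one of
// {8, 16, 32, 96, 128}.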

namespace {
/// If the source/target of a GatherToLDSOp is a CastOp that only removes
/// static information or changes layout, the cast can be skipped.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(gatherOp,
                                   [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
} // namespace

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
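
// Illustrative fold performed by FoldGatherToLDSOfCast (IR written
// schematically): a cast that only erases static shape information, e.g.
//   %c = memref.cast %src : memref<64x64xf32> to memref<?x?xf32>
//   amdgpu.gather_to_lds %c[...], %lds[...] : ...
// is bypassed so the op reads directly from %src; casts that cannot be
// folded into their consumer are left untouched.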

//===----------------------------------------------------------------------===//
// TransposeLoadOp
//===----------------------------------------------------------------------===//

LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // ElementSize -> NumElements
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }

  return success();
}
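
// Illustrative result types per the table above: 16 elements of a 4-bit or
// 6-bit element type, vector<8xi8> for 8-bit elements, or vector<4xf16> for
// 16-bit elements; any other (element size, element count) pairing is
// rejected.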

//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//

namespace {
/// Check if the scales input is used in other scaled mfma's while they exist.
/// If they're unused then pack the scales.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // For every scale operand of this ScaledMFMAOp, if the scale is produced
    // by the extraction of a single scale from some vector, then attempt to
    // extract 4 values from that vector instead.
    //
    // Example: (f8 here means f8E8M0FNU)
    // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
    // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
    // amdgpu.scaled_mfma(%scale[0] * ...
    //
    // rewrite to:
    //
    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
    // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
    // amdgpu.scaled_mfma(%scale[0-3] * ...
    //
    // This creates duplicate shape_casts for every use but these will be
    // removed in CSE.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      }
      // If the extracted value is not a single scalar, then it has been
      // packed.
      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");
      }

      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");
      }

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType) {
        return rewriter.notifyMatchFailure(op, "not a vector type");
      }

      // We do not handle dynamic dims yet, assume that the input is padded to
      // a static shape now.
      if (!scaleSrcType.hasStaticShape()) {
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");
      }

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4) {
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");
      }

      // Find a linearized idx using the size and offsets of the extract op.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // All n scales (where n is the total number of scales) must now be
      // extracted in chunks of 4 elements. This is done by dividing the
      // original vector of scales into groups of 4 elements
      // at offsets 0, 4, ..., m (where m = n/4). All extractions of a
      // scale at a particular index are now replaced with an extraction
      // of the entire group of 4 elements to which that index belongs.
      //
      // If the number of scales happens to be indivisible by 4, extract
      // the remaining n - m scales in a chunk of 4 elements starting at
      // offset n - 4.
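      //
      // Worked example (illustrative): with n = 10 scales and a linearized
      // index idx = 9, the 4-aligned chunk would start at offset 8 with only
      // 2 elements remaining, so the code below shifts to
      // offset = 10 - 4 = 6 and opsel = 4 - (10 - 9) = 3, i.e. the scale is
      // the last element of the extracted chunk [6, 10).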
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      int64_t size = 4l;
      // Accommodate remaining elements in the case of non-4-divisible vectors.
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
                                        scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
          ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
} // namespace

void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"