//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMDGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <limits>
#include <optional>
using namespace mlir;
using namespace mlir::amdgpu;

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// mxfp float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// FatRawBufferCastOp
//===----------------------------------------------------------------------===//

/// Convert the type `source` to one with the same sizes and strides (and
/// offset, unless `resetOffset` is true, in which case the offset is reset to
/// 0). If the offset should be reset but the layout of `source` is neither the
/// identity layout nor a strided layout, this function fails.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // Special case: if resetting the offset causes the strided layout to
    // become the identity layout, then reset to the identity layout.
    // TODO: this'll get a lot simpler when we have the contiguous layout.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
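
// Illustrative example (not from the original source): with resetOffset =
// true,
//   memref<8x8xf32, strided<[8, 1], offset: 4>>
// maps to
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// because strides [8, 1] with offset 0 are exactly the identity layout for an
// 8x8 shape, while
//   memref<8x8xf32, strided<[16, 1], offset: 4>>
// keeps a strided layout with its offset reset to 0.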

LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}

LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}

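// Note: on AMDGPU, numeric memref address spaces follow the LLVM AMDGPU
// convention: 0 is flat, 1 is global, 3 is workgroup (LDS), and 7 is the
// buffer fat pointer. That is why the predicates below accept those integer
// values alongside the dialect address-space attributes.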
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}

//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());

  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}

LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}

static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}

template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // The computed index would overflow the 32-bit buffer offset, so nothing
    // can be concluded statically; don't fold.
    return false;
  return result >= bufferType.getNumElements();
}

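// Illustrative example (hypothetical IR): with bounds checking enabled, a
// raw_buffer_load from a memref<4xi32> at constant index 4 computes
// result = 0 (offset) + 4 * 1 (stride) = 4 >= 4 elements, so it is statically
// out of bounds. The patterns below fold such loads to zero (out-of-bounds
// buffer reads return 0) and erase such writes.
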
namespace {
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};

template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();

    rw.eraseOp(op);
    return success();
  }
};
} // end namespace
289 
290 void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
291  MLIRContext *context) {
292  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
293 }
294 
295 void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
296  MLIRContext *context) {
297  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
298 }
299 
300 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
301  RewritePatternSet &results, MLIRContext *context) {
302  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
303 }
304 
305 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
306  RewritePatternSet &results, MLIRContext *context) {
307  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
308 }
309 
310 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
311  RewritePatternSet &results, MLIRContext *context) {
312  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
313 }
314 
315 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
316  RewritePatternSet &results, MLIRContext *context) {
317  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
318 }
319 
320 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
321  RewritePatternSet &results, MLIRContext *context) {
322  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
323  context);
324 }
325 
//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type sourceBType = getSourceB().getType();
  Type destType = getDestC().getType();

  // The ODS definition guarantees vector operands, so cast<> is safe here.
  VectorType sourceVectorAType = cast<VectorType>(sourceAType);
  VectorType sourceVectorBType = cast<VectorType>(sourceBType);
  VectorType destVectorType = cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type sourceBElemType = sourceVectorBType.getElementType();
  Type destElemType = destVectorType.getElementType();

  if (sourceVectorAType.getNumElements() !=
      sourceVectorBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceVectorAType << " vs. " << sourceVectorBType;
  }

  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }

  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  if (sourceAElemType != sourceBElemType &&
      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
    return emitOpError(
               "source element types must match (except for fp8) but have ")
           << sourceAType << " and " << sourceBType;
  }
  return success();
}
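
// For instance (illustrative): f16 x f16 sources with an f32 accumulator pass
// the float/float check, and i8 x i8 sources with an i32 accumulator pass the
// int/int check. Mixed 8-bit float inputs (f8E4M3FN with f8E5M2) are the one
// permitted element-type mismatch between the two sources.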

//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
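
// Worked example (illustrative): for a 32x32x8 f16 MFMA with one block at
// wave size 64, each lane must supply (32 * 8 * 1) / 64 = 4 source values
// (vector<4xf16>) and hold (32 * 32 * 1) / 64 = 16 accumulator values
// (vector<16xf32>).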

//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
  } break;

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
  } break;

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
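
// For reference (illustrative): quad_perm([1, 0, 3, 2]) swaps adjacent lane
// pairs within each group of four lanes, which is why every entry must select
// a source lane in [0, 3]; the row_shl/row_shr/row_ror shift distance is
// likewise bounded to 1-15 lanes within a row.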

//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//

LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
    return emitOpError("destination types must be contiguous");

  auto elemType = srcType.getElementType();
  // Check that the $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // The copied type size must be 1, 2, 4, 12, or 16 bytes.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
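
// Illustrative example: a vector<2xf16> transfer type is 2 * 16 = 32 bits and
// is accepted, whereas vector<4xf16> (64 bits) is rejected because 64 is not
// one of the supported transfer widths.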

namespace {
/// If the source/target of a GatherToLDSOp is a CastOp that only removes
/// static information or changes layout, the cast can be skipped.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(gatherOp,
                                   [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
} // namespace

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}

//===----------------------------------------------------------------------===//
// TransposeLoadOp
//===----------------------------------------------------------------------===//

LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Element size (in bits) -> supported number of elements.
  const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == KValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected number of elements: ")
           << validNumElems->second;
  }

  return success();
}
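
// Illustrative note: each supported combination loads 64 bits per lane
// (4 x 16, 8 x 8, 16 x 4), except the 6-bit entry, whose 16 elements pack
// into 96 bits.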

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"