MLIR  21.0.0git
AMDGPUDialect.cpp
Go to the documentation of this file.
1 //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMDGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <limits>
#include <optional>
32 
33 using namespace mlir;
34 using namespace mlir::amdgpu;
35 
36 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
37 
// Register all ODS-generated AMDGPU operations and attributes with the
// dialect. The lists are pulled in from the tablegen-generated .cpp.inc files.
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}
48 
49 //===----------------------------------------------------------------------===//
50 // 8-bit float ops
51 //===----------------------------------------------------------------------===//
52 LogicalResult PackedTrunc2xFp8Op::verify() {
53  if (getExisting() && getExisting().getType() != getResult().getType())
54  return emitOpError("existing values must have same type as result");
55  return success();
56 }
57 
58 LogicalResult PackedStochRoundFp8Op::verify() {
59  if (getExisting() && getExisting().getType() != getResult().getType())
60  return emitOpError("existing values must have same type as result");
61  return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // mxfp float ops
66 //===----------------------------------------------------------------------===//
67 LogicalResult PackedScaledTruncOp::verify() {
68  if (getExisting() && getExisting().getType() != getResult().getType())
69  return emitOpError("existing values must have same type as result");
70  return success();
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // FatRawBufferCastOp
75 //===----------------------------------------------------------------------===//
76 
77 /// Convert the type `source` to one with the same sizes and strides - and
78 /// offset, unless `stripOffset` is true, in which case the offset is reset to
79 /// 0, if the offset should be reset but the layout of `source` isn't either the
80 /// identity layout or a strided layout, this function fails.
81 static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
82  bool resetOffset) {
83  MLIRContext *ctx = source.getContext();
84  MemRefType::Builder mb(source);
85  mb.setMemorySpace(
86  amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
87  MemRefLayoutAttrInterface layout = source.getLayout();
88  if (resetOffset && !layout.isIdentity()) {
89  auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
90  if (!stridedLayout)
91  return failure();
92  mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
93  }
94  return (MemRefType)(mb);
95 }
96 
97 LogicalResult FatRawBufferCastOp::inferReturnTypes(
98  MLIRContext *context, std::optional<Location> location, ValueRange operands,
99  DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
100  SmallVectorImpl<Type> &inferredReturnTypes) {
101  Adaptor adaptor(operands, attributes, properties, regions);
102  auto sourceType =
103  dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
104  if (!sourceType)
105  return failure();
106  FailureOr<MemRefType> resultType =
107  getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
108  if (failed(resultType))
109  return failure();
110  inferredReturnTypes = SmallVector<Type>{*resultType};
111  return success();
112 }
113 
114 LogicalResult FatRawBufferCastOp::verify() {
115  FailureOr<MemRefType> expectedResultType =
116  getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
117  if (failed(expectedResultType))
118  return emitOpError("source type ")
119  << getSource().getType() << " can't have its offset reset";
120  if (getResult().getType() != *expectedResultType)
121  return emitOpError("expected result type to be ")
122  << *expectedResultType << " but got " << getResult().getType();
123  return success();
124 }
125 
126 static bool hasGlobalMemorySpace(Attribute memorySpace) {
127  if (!memorySpace)
128  return true;
129  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
130  return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
131  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
132  return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
133  return false;
134 }
135 
136 static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
137  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
138  return intMemorySpace.getInt() == 3;
139  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
140  return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
141  return false;
142 }
143 
144 static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
145  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
146  return intMemorySpace.getInt() == 7;
147  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
148  return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
149  return false;
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // RawBuffer*Op
154 //===----------------------------------------------------------------------===//
155 template <typename T>
156 static LogicalResult verifyRawBufferOp(T &op) {
157  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
158  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
159 
160  if (!isGlobal)
161  return op.emitOpError(
162  "Buffer ops must operate on a memref in global memory");
163  if (!bufferType.hasRank())
164  return op.emitOpError(
165  "Cannot meaningfully buffer_store to an unranked memref");
166  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
167  return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
168  " indices to memref");
169  return success();
170 }
171 
// All raw buffer ops share the same structural checks; see verifyRawBufferOp.
LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
195 
196 static std::optional<uint32_t> getConstantUint32(Value v) {
197  APInt cst;
198  if (!v.getType().isInteger(32))
199  return std::nullopt;
200  if (matchPattern(v, m_ConstantInt(&cst)))
201  return cst.getZExtValue();
202  return std::nullopt;
203 }
204 
/// Returns true when a raw-buffer access can be proven at compile time to lie
/// past the end of its statically shaped memref, in which case hardware
/// bounds checking makes it a no-op (loads return 0, writes are dropped).
/// Returns false conservatively whenever anything is non-constant.
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  // Without bounds checking an OOB access is not defined away, so the op
  // cannot be removed.
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  // Linearized element index: layout offset + optional static index offset,
  // plus (below) the constant SGPR offset and the strided index terms.
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    // Every index must be a known i32 constant or we cannot decide.
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflow means don't drop
    return false;
  return result >= bufferType.getNumElements();
}
240 
241 namespace {
242 template <typename OpType>
243 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
245 
246  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
247  if (!staticallyOutOfBounds(op))
248  return failure();
249  Type loadType = op.getResult().getType();
250  rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
251  rw.getZeroAttr(loadType));
252  return success();
253  }
254 };
255 
256 template <typename OpType>
257 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
259 
260  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
261  if (!staticallyOutOfBounds(op))
262  return failure();
263 
264  rw.eraseOp(op);
265  return success();
266  }
267 };
268 } // end namespace
269 
// Canonicalization registration: value-producing ops (load, cmpswap) fold
// statically-OOB accesses to zero; write-style ops erase them.
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

// Cmpswap returns the previous value, so it folds like a load.
void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
305 
306 //===----------------------------------------------------------------------===//
307 // WMMAOp
308 //===----------------------------------------------------------------------===//
309 LogicalResult WMMAOp::verify() {
310  Type sourceAType = getSourceA().getType();
311  Type sourceBType = getSourceB().getType();
312  Type destType = getDestC().getType();
313 
314  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
315  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
316  VectorType destVectorType = dyn_cast<VectorType>(destType);
317 
318  Type sourceAElemType = sourceVectorAType.getElementType();
319  Type sourceBElemType = sourceVectorBType.getElementType();
320  Type destElemType = destVectorType.getElementType();
321 
322  if (sourceVectorAType.getNumElements() !=
323  sourceVectorBType.getNumElements()) {
324  return emitOpError("source vectors have different lengths: ")
325  << sourceVectorAType << " vs. " << sourceVectorBType;
326  }
327 
328  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
329  bool isSrcFloat =
330  isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
331  sourceAElemType);
332 
333  if (isDestFloat && !isSrcFloat) {
334  return emitOpError("Expected float sources with float destination");
335  }
336 
337  if (!isDestFloat && isSrcFloat) {
338  return emitOpError("Expected int sources with int destination");
339  }
340 
341  if (sourceAElemType != sourceBElemType &&
342  !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
343  isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
344  return emitOpError(
345  "source element types much match (except for fp8) but have ")
346  << sourceAType << " and " << sourceBType;
347  }
348  return success();
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // MFMAOp
353 //===----------------------------------------------------------------------===//
/// Verifies an MFMA op: source/dest vector lengths must match the m/n/k/blocks
/// attributes for a 64-lane wavefront, and the lane-permutation (cbsz, abid,
/// blgp) and negation modifiers obey their per-type restrictions.
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  // Scalars are treated as length-1 vectors of themselves.
  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
    // Small-float (f8/f6/f4) MFMAs allow A and B to differ in element type,
    // but both must be small-float and of equal length.
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    // All other element types require A and B to have identical types.
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  // Each lane of the wave holds (m * k * blocks) / waveSize source elements.
  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  // Each lane holds (m * n * blocks) / waveSize accumulator/result elements.
  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  // f64 MFMAs do not support lane permutation of either input.
  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  // abid selects one of 2**cbsz broadcast groups, so it must fit that range.
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
430 
431 //===----------------------------------------------------------------------===//
432 // DPPOp
433 //===----------------------------------------------------------------------===//
/// Verifies a DPP (data-parallel primitives) op: the source must fit in 64
/// bits, and the permutation argument must match what the selected DPP kind
/// requires (a 4-element quad_perm array, a row shift amount, or nothing).
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  // NOTE(review): getIntOrFloatBitWidth asserts on non-int/float types;
  // presumably the ODS operand constraint guarantees int-or-float — confirm.
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {

  case DPPPerm::quad_perm: {
    // quad_perm takes an array of exactly 4 lane selectors, each in [0, 3].
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
  } break;

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    // Row shifts/rotates take a mandatory amount in [1, 15].
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    // NOTE(review): a permArgument that is present but not an IntegerAttr
    // falls through this check silently — confirm that is intended.
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
  } break;

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    // These permutations are parameterless; only absent or unit attributes
    // are accepted.
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
492 
/// Verifies a gather-to-LDS op: contiguous LDS destination, matching element
/// types, a transfer width the hardware supports, and the expected address
/// spaces on both sides (global/fat-raw-buffer source, workgroup destination).
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  // The destination is written as one contiguous chunk, so every trailing
  // dimension (i.e. all of them) must be contiguous.
  if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
    return emitOpError("destination types must be contiguous");

  auto elemType = srcType.getElementType();
  // Check $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // copy type sizes should be 1, 2, 4, 12 or 16 bytes.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  // NOTE(review): "Transfering" typo lives in the emitted diagnostic; left
  // unchanged here since downstream tests may match the exact text.
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transfering type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
528 
529 LogicalResult TransposeLoadOp::verify() {
530  MemRefType srcType = cast<MemRefType>(getSrc().getType());
531 
532  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
533  return emitOpError("source memory address space must be Workgroup");
534 
535  auto transferType = cast<VectorType>(getType());
536  size_t numElements = transferType.getNumElements();
537  size_t elementTypeSize =
538  transferType.getElementType().getIntOrFloatBitWidth();
539 
540  // ElementSize -> NumElements
541  const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
542  {4, 16},
543  {6, 16},
544  {8, 8},
545  {16, 4},
546  };
547 
548  auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
549  if (validNumElems == KValidLoadSizeMap.end()) {
550  return emitOpError("Unsupported element type size for transpose load: ")
551  << elementTypeSize << " bits";
552  }
553  if (numElements != validNumElems->second) {
554  return emitOpError(
555  "Transferring type size mismatch: expected num of elements: ")
556  << validNumElems->second;
557  }
558 
559  return success();
560 }
561 
562 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
563 
564 #define GET_ATTRDEF_CLASSES
565 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
566 
567 #define GET_OP_CLASSES
568 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless stripOffset is tr...
static bool hasGlobalMemorySpace(Attribute memorySpace)
static bool hasWorkgroupMemorySpace(Attribute memorySpace)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool hasFatRawBufferMemorySpace(Attribute memorySpace)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1218::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:182
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:203
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:208
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:810
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isFloat() const
Return true if this is an float type (with the specified width).
Definition: Types.cpp:45
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314