MLIR  21.0.0git
AMDGPUDialect.cpp
Go to the documentation of this file.
1 //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMDGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Diagnostics.h"
23 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 #include <limits>
30 #include <optional>
31 
32 using namespace mlir;
33 using namespace mlir::amdgpu;
34 
35 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
36 
37 void AMDGPUDialect::initialize() {
38  addOperations<
39 #define GET_OP_LIST
40 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
41  >();
42  addAttributes<
43 #define GET_ATTRDEF_LIST
44 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
45  >();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // 8-bit float ops
50 //===----------------------------------------------------------------------===//
51 LogicalResult PackedTrunc2xFp8Op::verify() {
52  if (getExisting() && getExisting().getType() != getResult().getType())
53  return emitOpError("existing values must have same type as result");
54  return success();
55 }
56 
57 LogicalResult PackedStochRoundFp8Op::verify() {
58  if (getExisting() && getExisting().getType() != getResult().getType())
59  return emitOpError("existing values must have same type as result");
60  return success();
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // mxfp float ops
65 //===----------------------------------------------------------------------===//
66 LogicalResult PackedScaledTruncOp::verify() {
67  if (getExisting() && getExisting().getType() != getResult().getType())
68  return emitOpError("existing values must have same type as result");
69  return success();
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // FatRawBufferCastOp
74 //===----------------------------------------------------------------------===//
75 
76 /// Convert the type `source` to one with the same sizes and strides - and
77 /// offset, unless `stripOffset` is true, in which case the offset is reset to
78 /// 0, if the offset should be reset but the layout of `source` isn't either the
79 /// identity layout or a strided layout, this function fails.
80 static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
81  bool resetOffset) {
82  MLIRContext *ctx = source.getContext();
83  MemRefType::Builder mb(source);
84  mb.setMemorySpace(
85  amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
86  MemRefLayoutAttrInterface layout = source.getLayout();
87  if (resetOffset && !layout.isIdentity()) {
88  auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
89  if (!stridedLayout)
90  return failure();
91  mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
92  }
93  return (MemRefType)(mb);
94 }
95 
96 LogicalResult FatRawBufferCastOp::inferReturnTypes(
97  MLIRContext *context, std::optional<Location> location, ValueRange operands,
98  DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
99  SmallVectorImpl<Type> &inferredReturnTypes) {
100  Adaptor adaptor(operands, attributes, properties, regions);
101  auto sourceType =
102  dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
103  if (!sourceType)
104  return failure();
105  FailureOr<MemRefType> resultType =
106  getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
107  if (failed(resultType))
108  return failure();
109  inferredReturnTypes = SmallVector<Type>{*resultType};
110  return success();
111 }
112 
113 LogicalResult FatRawBufferCastOp::verify() {
114  FailureOr<MemRefType> expectedResultType =
115  getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
116  if (failed(expectedResultType))
117  return emitOpError("source type ")
118  << getSource().getType() << " can't have its offset reset";
119  if (getResult().getType() != *expectedResultType)
120  return emitOpError("expected result type to be ")
121  << *expectedResultType << " but got " << getResult().getType();
122  return success();
123 }
124 
125 static bool hasGlobalMemorySpace(Attribute memorySpace) {
126  if (!memorySpace)
127  return true;
128  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
129  return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
130  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
131  return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
132  return false;
133 }
134 
135 static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
136  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
137  return intMemorySpace.getInt() == 3;
138  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
139  return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
140  return false;
141 }
142 
143 static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
144  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
145  return intMemorySpace.getInt() == 7;
146  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
147  return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
148  return false;
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // RawBuffer*Op
153 //===----------------------------------------------------------------------===//
154 template <typename T>
155 static LogicalResult verifyRawBufferOp(T &op) {
156  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
157  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
158 
159  if (!isGlobal)
160  return op.emitOpError(
161  "Buffer ops must operate on a memref in global memory");
162  if (!bufferType.hasRank())
163  return op.emitOpError(
164  "Cannot meaningfully buffer_store to an unranked memref");
165  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
166  return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
167  " indices to memref");
168  return success();
169 }
170 
171 LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
172 
173 LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
174 
175 LogicalResult RawBufferAtomicFaddOp::verify() {
176  return verifyRawBufferOp(*this);
177 }
178 
179 LogicalResult RawBufferAtomicFmaxOp::verify() {
180  return verifyRawBufferOp(*this);
181 }
182 
183 LogicalResult RawBufferAtomicSmaxOp::verify() {
184  return verifyRawBufferOp(*this);
185 }
186 
187 LogicalResult RawBufferAtomicUminOp::verify() {
188  return verifyRawBufferOp(*this);
189 }
190 
191 LogicalResult RawBufferAtomicCmpswapOp::verify() {
192  return verifyRawBufferOp(*this);
193 }
194 
195 static std::optional<uint32_t> getConstantUint32(Value v) {
196  APInt cst;
197  if (!v.getType().isInteger(32))
198  return std::nullopt;
199  if (matchPattern(v, m_ConstantInt(&cst)))
200  return cst.getZExtValue();
201  return std::nullopt;
202 }
203 
204 template <typename OpType>
205 static bool staticallyOutOfBounds(OpType op) {
206  if (!op.getBoundsCheck())
207  return false;
208  MemRefType bufferType = op.getMemref().getType();
209  if (!bufferType.hasStaticShape())
210  return false;
211  int64_t offset;
212  SmallVector<int64_t> strides;
213  if (failed(bufferType.getStridesAndOffset(strides, offset)))
214  return false;
215  int64_t result = offset + op.getIndexOffset().value_or(0);
216  if (op.getSgprOffset()) {
217  std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
218  if (!sgprOffset)
219  return false;
220  result += *sgprOffset;
221  }
222  if (strides.size() != op.getIndices().size())
223  return false;
224  int64_t indexVal = 0;
225  for (auto pair : llvm::zip(strides, op.getIndices())) {
226  int64_t stride = std::get<0>(pair);
227  Value idx = std::get<1>(pair);
228  std::optional<uint32_t> idxVal = getConstantUint32(idx);
229  if (!idxVal)
230  return false;
231  indexVal += stride * *idxVal;
232  }
233  result += indexVal;
234  if (result > std::numeric_limits<uint32_t>::max())
235  // Overflow means don't drop
236  return false;
237  return result >= bufferType.getNumElements();
238 }
239 
240 namespace {
241 template <typename OpType>
242 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
244 
245  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
246  if (!staticallyOutOfBounds(op))
247  return failure();
248  Type loadType = op.getResult().getType();
249  rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
250  rw.getZeroAttr(loadType));
251  return success();
252  }
253 };
254 
255 template <typename OpType>
256 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
258 
259  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
260  if (!staticallyOutOfBounds(op))
261  return failure();
262 
263  rw.eraseOp(op);
264  return success();
265  }
266 };
267 } // end namespace
268 
269 void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
270  MLIRContext *context) {
271  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
272 }
273 
274 void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
275  MLIRContext *context) {
276  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
277 }
278 
279 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
280  RewritePatternSet &results, MLIRContext *context) {
281  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
282 }
283 
284 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
285  RewritePatternSet &results, MLIRContext *context) {
286  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
287 }
288 
289 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
290  RewritePatternSet &results, MLIRContext *context) {
291  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
292 }
293 
294 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
295  RewritePatternSet &results, MLIRContext *context) {
296  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
297 }
298 
299 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
300  RewritePatternSet &results, MLIRContext *context) {
301  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
302  context);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // WMMAOp
307 //===----------------------------------------------------------------------===//
308 LogicalResult WMMAOp::verify() {
309  Type sourceAType = getSourceA().getType();
310  Type sourceBType = getSourceB().getType();
311  Type destType = getDestC().getType();
312 
313  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
314  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
315  VectorType destVectorType = dyn_cast<VectorType>(destType);
316 
317  Type sourceAElemType = sourceVectorAType.getElementType();
318  Type sourceBElemType = sourceVectorBType.getElementType();
319  Type destElemType = destVectorType.getElementType();
320 
321  if (sourceVectorAType.getNumElements() !=
322  sourceVectorBType.getNumElements()) {
323  return emitOpError("source vectors have different lengths: ")
324  << sourceVectorAType << " vs. " << sourceVectorBType;
325  }
326 
327  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
328  bool isSrcFloat =
329  isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
330  sourceAElemType);
331 
332  if (isDestFloat && !isSrcFloat) {
333  return emitOpError("Expected float sources with float destination");
334  }
335 
336  if (!isDestFloat && isSrcFloat) {
337  return emitOpError("Expected int sources with int destination");
338  }
339 
340  if (sourceAElemType != sourceBElemType &&
341  !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
342  isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
343  return emitOpError(
344  "source element types much match (except for fp8) but have ")
345  << sourceAType << " and " << sourceBType;
346  }
347  return success();
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // MFMAOp
352 //===----------------------------------------------------------------------===//
353 LogicalResult MFMAOp::verify() {
354  constexpr uint32_t waveSize = 64;
355  Builder b(getContext());
356 
357  Type sourceType = getSourceA().getType();
358  Type destType = getDestC().getType();
359 
360  Type sourceElem = sourceType, destElem = destType;
361  uint32_t sourceLen = 1, destLen = 1;
362  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
363  sourceLen = sourceVector.getNumElements();
364  sourceElem = sourceVector.getElementType();
365  }
366  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
367  destLen = destVector.getNumElements();
368  destElem = destVector.getElementType();
369  }
370 
371  Type sourceBType = getSourceB().getType();
372  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
373  int64_t sourceBLen = 1;
374  Type sourceBElem = sourceBType;
375  if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
376  sourceBLen = sourceBVector.getNumElements();
377  sourceBElem = sourceBVector.getElementType();
378  }
379  if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
380  !sourceBElem.isFloat(4))
381  return emitOpError("expected both source operands to have small-float "
382  "elements if one does");
383  if (sourceLen != sourceBLen)
384  return emitOpError(
385  "expected both small-float source vectors to have the same length");
386  } else {
387  if (sourceType != sourceBType)
388  return emitOpError("expected both non-small-float source operand types "
389  "to match exactly");
390  }
391  // Normalize the wider integer types the compiler expects to i8
392  if (sourceElem.isInteger(32)) {
393  sourceLen *= 4;
394  sourceElem = b.getI8Type();
395  }
396  if (sourceElem.isInteger(64)) {
397  sourceLen *= 8;
398  sourceElem = b.getI8Type();
399  }
400 
401  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
402  if (sourceLen != numSourceElems)
403  return emitOpError("expected " + Twine(numSourceElems) +
404  " source values for this operation but got " +
405  Twine(sourceLen));
406 
407  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
408  if (destLen != numDestElems)
409  return emitOpError("expected " + Twine(numDestElems) +
410  " result values for this operation but got " +
411  Twine(destLen));
412 
413  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
414  return emitOpError(
415  "double-precision ops do not support permuting lanes of B");
416  if (destElem.isF64() && getCbsz() != 0)
417  return emitOpError(
418  "double-precision ops do not support permuting lanes of A");
419  if (getAbid() >= (1u << getCbsz()))
420  return emitOpError(
421  "block ID for permuting A (abid) must be below 2 ** cbsz");
422 
423  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
424  return emitOpError(
425  "negation flags only available for double-precision operations");
426 
427  return success();
428 }
429 
430 //===----------------------------------------------------------------------===//
431 // DPPOp
432 //===----------------------------------------------------------------------===//
433 LogicalResult DPPOp::verify() {
434  Type srcType = getSrc().getType();
435  if (srcType.getIntOrFloatBitWidth() > 64) {
436  return emitOpError("integer and floating point types larger than 64 bits "
437  "are not supported");
438  }
439 
440  DPPPerm kind = getKind();
441  Attribute permArgument = getPermArgument().value_or(Attribute{});
442 
443  switch (kind) {
444 
445  case DPPPerm::quad_perm: {
446  auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
447  if (!quadPermAttr || quadPermAttr.size() != 4) {
448  return emitOpError("quad_perm attribute must have exactly 4 elements");
449  }
450  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
451  int32_t num = elem.getInt();
452  if (num < 0 || num > 3) {
453  return emitOpError(
454  "Each element of quad_perm must be in the range [0, 3]");
455  }
456  }
457  } break;
458 
459  case DPPPerm::row_shl:
460  case DPPPerm::row_shr:
461  case DPPPerm::row_ror: {
462  if (!permArgument) {
463  return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
464  "' value not specified");
465  }
466  if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
467  uint32_t attrValue = intAttr.getInt();
468  if (attrValue < 1 || attrValue > 15) {
469  return emitOpError("Attribute value must be between 1 and 15");
470  }
471  }
472  } break;
473 
474  case DPPPerm::wave_shl:
475  case DPPPerm::wave_shr:
476  case DPPPerm::wave_rol:
477  case DPPPerm::wave_ror:
478  case DPPPerm::row_mirror:
479  case DPPPerm::row_half_mirror:
480  case DPPPerm::row_bcast_15:
481  case DPPPerm::row_bcast_31: {
482  if (permArgument && !isa<UnitAttr>(permArgument)) {
483  return emitOpError("Expected unit attribute for permArgument, but found "
484  "non-trivial argument");
485  }
486  break;
487  }
488  }
489  return success();
490 }
491 
492 LogicalResult GatherToLDSOp::verify() {
493  MemRefType srcType = cast<MemRefType>(getSrc().getType());
494  MemRefType dstType = cast<MemRefType>(getDst().getType());
495 
496  if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
497  return emitOpError("destination types must be contiguous");
498 
499  auto elemType = srcType.getElementType();
500  // Check $src and $dst element types are the same.
501  if (elemType != dstType.getElementType())
502  return emitOpError("source and destination element types must match");
503 
504  // copy type sizes should be 1, 2, or 4 bytes.
505  auto transferType = getTransferType();
506  size_t transferSize;
507  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
508  transferSize = vectorTransfer.getNumElements() *
509  vectorTransfer.getElementTypeBitWidth();
510  } else {
511  transferSize = transferType.getIntOrFloatBitWidth();
512  }
513  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
514  return emitOpError("Transfering type size must be 8, 16, or 32 bits");
515 
516  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
517  !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
518  return emitOpError(
519  "source memory address space must be global or fat raw buffer");
520 
521  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
522  return emitOpError("destination memory address space must be Workgroup");
523 
524  return success();
525 }
526 
527 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
528 
529 #define GET_ATTRDEF_CLASSES
530 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
531 
532 #define GET_OP_CLASSES
533 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless resetOffset is true, in which case the offset is reset to 0...
static bool hasGlobalMemorySpace(Attribute memorySpace)
static bool hasWorkgroupMemorySpace(Attribute memorySpace)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool hasFatRawBufferMemorySpace(Attribute memorySpace)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1204::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:322
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:166
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:187
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:192
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isFloat() const
Return true if this is an float type (with the specified width).
Definition: Types.cpp:45
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314