MLIR  21.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Location.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/SmallVectorExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include <cassert>
35 #include <cstdint>
36 #include <numeric>
37 
38 using namespace mlir;
39 
40 /// Returns the integer value from the first valid input element, assuming Value
41 /// inputs are defined by a constant index ops and Attribute inputs are integer
42 /// attributes.
43 static uint64_t getFirstIntValue(ArrayAttr attr) {
44  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
45 }
46 
47 /// Returns the number of bits for the given scalar/vector type.
48 static int getNumBits(Type type) {
49  // TODO: This does not take into account any memory layout or widening
50  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
51  // though in practice it will likely be stored as in a 4xi64 vector register.
52  if (auto vectorType = dyn_cast<VectorType>(type))
53  return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
54  return type.getIntOrFloatBitWidth();
55 }
56 
57 namespace {
58 
59 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
61 
62  LogicalResult
63  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
64  ConversionPatternRewriter &rewriter) const override {
65  Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
66  if (!dstType)
67  return failure();
68 
69  // If dstType is same as the source type or the vector size is 1, it can be
70  // directly replaced by the source.
71  if (dstType == adaptor.getSource().getType() ||
72  shapeCastOp.getResultVectorType().getNumElements() == 1) {
73  rewriter.replaceOp(shapeCastOp, adaptor.getSource());
74  return success();
75  }
76 
77  // Lowering for size-n vectors when n > 1 hasn't been implemented.
78  return failure();
79  }
80 };
81 
82 struct VectorBitcastConvert final
83  : public OpConversionPattern<vector::BitCastOp> {
85 
86  LogicalResult
87  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
88  ConversionPatternRewriter &rewriter) const override {
89  Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
90  if (!dstType)
91  return failure();
92 
93  if (dstType == adaptor.getSource().getType()) {
94  rewriter.replaceOp(bitcastOp, adaptor.getSource());
95  return success();
96  }
97 
98  // Check that the source and destination type have the same bitwidth.
99  // Depending on the target environment, we may need to emulate certain
100  // types, which can cause issue with bitcast.
101  Type srcType = adaptor.getSource().getType();
102  if (getNumBits(dstType) != getNumBits(srcType)) {
103  return rewriter.notifyMatchFailure(
104  bitcastOp,
105  llvm::formatv("different source ({0}) and target ({1}) bitwidth",
106  srcType, dstType));
107  }
108 
109  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
110  adaptor.getSource());
111  return success();
112  }
113 };
114 
115 struct VectorBroadcastConvert final
116  : public OpConversionPattern<vector::BroadcastOp> {
118 
119  LogicalResult
120  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
121  ConversionPatternRewriter &rewriter) const override {
122  Type resultType =
123  getTypeConverter()->convertType(castOp.getResultVectorType());
124  if (!resultType)
125  return failure();
126 
127  if (isa<spirv::ScalarType>(resultType)) {
128  rewriter.replaceOp(castOp, adaptor.getSource());
129  return success();
130  }
131 
132  SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
133  adaptor.getSource());
134  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
135  source);
136  return success();
137  }
138 };
139 
140 // SPIR-V does not have a concept of a poison index for certain instructions,
141 // which creates a UB hazard when lowering from otherwise equivalent Vector
142 // dialect instructions, because this index will be considered out-of-bounds.
143 // To avoid this, this function implements a dynamic sanitization that returns
144 // some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
145 // (presumably more efficient), and otherwise index 0 (always in-bounds).
146 static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
147  Location loc, Value dynamicIndex,
148  int64_t kPoisonIndex, unsigned vectorSize) {
149  if (llvm::isPowerOf2_32(vectorSize)) {
150  Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
151  loc, dynamicIndex.getType(),
152  rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
153  return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex,
154  inBoundsMask);
155  }
156  Value poisonIndex = rewriter.create<spirv::ConstantOp>(
157  loc, dynamicIndex.getType(),
158  rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
159  Value cmpResult =
160  rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
161  return rewriter.create<spirv::SelectOp>(
162  loc, cmpResult,
163  spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
164  dynamicIndex);
165 }
166 
167 struct VectorExtractOpConvert final
168  : public OpConversionPattern<vector::ExtractOp> {
170 
171  LogicalResult
172  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
173  ConversionPatternRewriter &rewriter) const override {
174  Type dstType = getTypeConverter()->convertType(extractOp.getType());
175  if (!dstType)
176  return failure();
177 
178  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
179  rewriter.replaceOp(extractOp, adaptor.getVector());
180  return success();
181  }
182 
183  if (std::optional<int64_t> id =
184  getConstantIntValue(extractOp.getMixedPosition()[0])) {
185  if (id == vector::ExtractOp::kPoisonIndex)
186  return rewriter.notifyMatchFailure(
187  extractOp,
188  "Static use of poison index handled elsewhere (folded to poison)");
189  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
190  extractOp, dstType, adaptor.getVector(),
191  rewriter.getI32ArrayAttr(id.value()));
192  } else {
193  Value sanitizedIndex = sanitizeDynamicIndex(
194  rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
195  vector::ExtractOp::kPoisonIndex,
196  extractOp.getSourceVectorType().getNumElements());
197  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
198  extractOp, dstType, adaptor.getVector(), sanitizedIndex);
199  }
200  return success();
201  }
202 };
203 
204 struct VectorExtractStridedSliceOpConvert final
205  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
207 
208  LogicalResult
209  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
210  ConversionPatternRewriter &rewriter) const override {
211  Type dstType = getTypeConverter()->convertType(extractOp.getType());
212  if (!dstType)
213  return failure();
214 
215  uint64_t offset = getFirstIntValue(extractOp.getOffsets());
216  uint64_t size = getFirstIntValue(extractOp.getSizes());
217  uint64_t stride = getFirstIntValue(extractOp.getStrides());
218  if (stride != 1)
219  return failure();
220 
221  Value srcVector = adaptor.getOperands().front();
222 
223  // Extract vector<1xT> case.
224  if (isa<spirv::ScalarType>(dstType)) {
225  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
226  srcVector, offset);
227  return success();
228  }
229 
230  SmallVector<int32_t, 2> indices(size);
231  std::iota(indices.begin(), indices.end(), offset);
232 
233  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
234  extractOp, dstType, srcVector, srcVector,
235  rewriter.getI32ArrayAttr(indices));
236 
237  return success();
238  }
239 };
240 
241 template <class SPIRVFMAOp>
242 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
244 
245  LogicalResult
246  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
247  ConversionPatternRewriter &rewriter) const override {
248  Type dstType = getTypeConverter()->convertType(fmaOp.getType());
249  if (!dstType)
250  return failure();
251  rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
252  adaptor.getRhs(), adaptor.getAcc());
253  return success();
254  }
255 };
256 
257 struct VectorFromElementsOpConvert final
258  : public OpConversionPattern<vector::FromElementsOp> {
260 
261  LogicalResult
262  matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
263  ConversionPatternRewriter &rewriter) const override {
264  Type resultType = getTypeConverter()->convertType(op.getType());
265  if (!resultType)
266  return failure();
267  OperandRange elements = op.getElements();
268  if (isa<spirv::ScalarType>(resultType)) {
269  // In the case with a single scalar operand / single-element result,
270  // pass through the scalar.
271  rewriter.replaceOp(op, elements[0]);
272  return success();
273  }
274  // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
275  // vector.from_elements cases should not need to be handled, only 1d.
276  assert(cast<VectorType>(resultType).getRank() == 1);
277  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
278  elements);
279  return success();
280  }
281 };
282 
283 struct VectorInsertOpConvert final
284  : public OpConversionPattern<vector::InsertOp> {
286 
287  LogicalResult
288  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
289  ConversionPatternRewriter &rewriter) const override {
290  if (isa<VectorType>(insertOp.getValueToStoreType()))
291  return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
292  if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
293  return rewriter.notifyMatchFailure(insertOp,
294  "unsupported dest vector type");
295 
296  // Special case for inserting scalar values into size-1 vectors.
297  if (insertOp.getValueToStoreType().isIntOrFloat() &&
298  insertOp.getDestVectorType().getNumElements() == 1) {
299  rewriter.replaceOp(insertOp, adaptor.getValueToStore());
300  return success();
301  }
302 
303  if (std::optional<int64_t> id =
304  getConstantIntValue(insertOp.getMixedPosition()[0])) {
305  if (id == vector::InsertOp::kPoisonIndex)
306  return rewriter.notifyMatchFailure(
307  insertOp,
308  "Static use of poison index handled elsewhere (folded to poison)");
309  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
310  insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
311  } else {
312  Value sanitizedIndex = sanitizeDynamicIndex(
313  rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
314  vector::InsertOp::kPoisonIndex,
315  insertOp.getDestVectorType().getNumElements());
316  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
317  insertOp, insertOp.getDest(), adaptor.getValueToStore(),
318  sanitizedIndex);
319  }
320  return success();
321  }
322 };
323 
324 struct VectorExtractElementOpConvert final
325  : public OpConversionPattern<vector::ExtractElementOp> {
327 
328  LogicalResult
329  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
330  ConversionPatternRewriter &rewriter) const override {
331  Type resultType = getTypeConverter()->convertType(extractOp.getType());
332  if (!resultType)
333  return failure();
334 
335  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
336  rewriter.replaceOp(extractOp, adaptor.getVector());
337  return success();
338  }
339 
340  APInt cstPos;
341  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
342  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
343  extractOp, resultType, adaptor.getVector(),
344  rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
345  else
346  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
347  extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
348  return success();
349  }
350 };
351 
352 struct VectorInsertElementOpConvert final
353  : public OpConversionPattern<vector::InsertElementOp> {
355 
356  LogicalResult
357  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
358  ConversionPatternRewriter &rewriter) const override {
359  Type vectorType = getTypeConverter()->convertType(insertOp.getType());
360  if (!vectorType)
361  return failure();
362 
363  if (isa<spirv::ScalarType>(vectorType)) {
364  rewriter.replaceOp(insertOp, adaptor.getSource());
365  return success();
366  }
367 
368  APInt cstPos;
369  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
370  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
371  insertOp, adaptor.getSource(), adaptor.getDest(),
372  cstPos.getSExtValue());
373  else
374  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
375  insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
376  adaptor.getPosition());
377  return success();
378  }
379 };
380 
381 struct VectorInsertStridedSliceOpConvert final
382  : public OpConversionPattern<vector::InsertStridedSliceOp> {
384 
385  LogicalResult
386  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
387  ConversionPatternRewriter &rewriter) const override {
388  Value srcVector = adaptor.getOperands().front();
389  Value dstVector = adaptor.getOperands().back();
390 
391  uint64_t stride = getFirstIntValue(insertOp.getStrides());
392  if (stride != 1)
393  return failure();
394  uint64_t offset = getFirstIntValue(insertOp.getOffsets());
395 
396  if (isa<spirv::ScalarType>(srcVector.getType())) {
397  assert(!isa<spirv::ScalarType>(dstVector.getType()));
398  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
399  insertOp, dstVector.getType(), srcVector, dstVector,
400  rewriter.getI32ArrayAttr(offset));
401  return success();
402  }
403 
404  uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
405  uint64_t insertSize =
406  cast<VectorType>(srcVector.getType()).getNumElements();
407 
408  SmallVector<int32_t, 2> indices(totalSize);
409  std::iota(indices.begin(), indices.end(), 0);
410  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
411  totalSize);
412 
413  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
414  insertOp, dstVector.getType(), dstVector, srcVector,
415  rewriter.getI32ArrayAttr(indices));
416 
417  return success();
418  }
419 };
420 
421 static SmallVector<Value> extractAllElements(
422  vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
423  VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
424  int numElements = static_cast<int>(srcVectorType.getDimSize(0));
425  SmallVector<Value> values;
426  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
427  Location loc = reduceOp.getLoc();
428 
429  for (int i = 0; i < numElements; ++i) {
430  values.push_back(rewriter.create<spirv::CompositeExtractOp>(
431  loc, srcVectorType.getElementType(), adaptor.getVector(),
432  rewriter.getI32ArrayAttr({i})));
433  }
434  if (Value acc = adaptor.getAcc())
435  values.push_back(acc);
436 
437  return values;
438 }
439 
440 struct ReductionRewriteInfo {
441  Type resultType;
442  SmallVector<Value> extractedElements;
443 };
444 
445 FailureOr<ReductionRewriteInfo> static getReductionInfo(
446  vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
447  ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
448  Type resultType = typeConverter.convertType(op.getType());
449  if (!resultType)
450  return failure();
451 
452  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
453  if (!srcVectorType || srcVectorType.getRank() != 1)
454  return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
455 
456  SmallVector<Value> extractedElements =
457  extractAllElements(op, adaptor, srcVectorType, rewriter);
458 
459  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
460 }
461 
462 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
463  typename SPIRVSMinOp>
464 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
466 
467  LogicalResult
468  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
469  ConversionPatternRewriter &rewriter) const override {
470  auto reductionInfo =
471  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
472  if (failed(reductionInfo))
473  return failure();
474 
475  auto [resultType, extractedElements] = *reductionInfo;
476  Location loc = reduceOp->getLoc();
477  Value result = extractedElements.front();
478  for (Value next : llvm::drop_begin(extractedElements)) {
479  switch (reduceOp.getKind()) {
480 
481 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
482  case vector::CombiningKind::kind: \
483  if (llvm::isa<IntegerType>(resultType)) { \
484  result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
485  } else { \
486  assert(llvm::isa<FloatType>(resultType)); \
487  result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
488  } \
489  break
490 
491 #define INT_OR_FLOAT_CASE(kind, fop) \
492  case vector::CombiningKind::kind: \
493  result = rewriter.create<fop>(loc, resultType, result, next); \
494  break
495 
496  INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
497  INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
498  INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
499  INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
500  INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
501  INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
502 
503  case vector::CombiningKind::AND:
504  case vector::CombiningKind::OR:
505  case vector::CombiningKind::XOR:
506  return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
507  default:
508  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
509  }
510 #undef INT_AND_FLOAT_CASE
511 #undef INT_OR_FLOAT_CASE
512  }
513 
514  rewriter.replaceOp(reduceOp, result);
515  return success();
516  }
517 };
518 
519 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
520 struct VectorReductionFloatMinMax final
521  : OpConversionPattern<vector::ReductionOp> {
523 
524  LogicalResult
525  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
526  ConversionPatternRewriter &rewriter) const override {
527  auto reductionInfo =
528  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
529  if (failed(reductionInfo))
530  return failure();
531 
532  auto [resultType, extractedElements] = *reductionInfo;
533  Location loc = reduceOp->getLoc();
534  Value result = extractedElements.front();
535  for (Value next : llvm::drop_begin(extractedElements)) {
536  switch (reduceOp.getKind()) {
537 
538 #define INT_OR_FLOAT_CASE(kind, fop) \
539  case vector::CombiningKind::kind: \
540  result = rewriter.create<fop>(loc, resultType, result, next); \
541  break
542 
543  INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
544  INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
545  INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
546  INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
547 
548  default:
549  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
550  }
551 #undef INT_OR_FLOAT_CASE
552  }
553 
554  rewriter.replaceOp(reduceOp, result);
555  return success();
556  }
557 };
558 
559 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
560 public:
562 
563  LogicalResult
564  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
565  ConversionPatternRewriter &rewriter) const override {
566  Type dstType = getTypeConverter()->convertType(op.getType());
567  if (!dstType)
568  return failure();
569  if (isa<spirv::ScalarType>(dstType)) {
570  rewriter.replaceOp(op, adaptor.getInput());
571  } else {
572  auto dstVecType = cast<VectorType>(dstType);
573  SmallVector<Value, 4> source(dstVecType.getNumElements(),
574  adaptor.getInput());
575  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
576  source);
577  }
578  return success();
579  }
580 };
581 
582 struct VectorShuffleOpConvert final
583  : public OpConversionPattern<vector::ShuffleOp> {
585 
586  LogicalResult
587  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
588  ConversionPatternRewriter &rewriter) const override {
589  VectorType oldResultType = shuffleOp.getResultVectorType();
590  Type newResultType = getTypeConverter()->convertType(oldResultType);
591  if (!newResultType)
592  return rewriter.notifyMatchFailure(shuffleOp,
593  "unsupported result vector type");
594 
595  auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
596 
597  VectorType oldV1Type = shuffleOp.getV1VectorType();
598  VectorType oldV2Type = shuffleOp.getV2VectorType();
599 
600  // When both operands and the result are SPIR-V vectors, emit a SPIR-V
601  // shuffle.
602  if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
603  oldResultType.getNumElements() > 1) {
604  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
605  shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
606  rewriter.getI32ArrayAttr(mask));
607  return success();
608  }
609 
610  // When at least one of the operands or the result becomes a scalar after
611  // type conversion for SPIR-V, extract all the required elements and
612  // construct the result vector.
613  auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
614  Value scalarOrVec, int32_t idx) -> Value {
615  if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
616  return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
617  idx);
618 
619  assert(idx == 0 && "Invalid scalar element index");
620  return scalarOrVec;
621  };
622 
623  int32_t numV1Elems = oldV1Type.getNumElements();
624  SmallVector<Value> newOperands(mask.size());
625  for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
626  Value vec = adaptor.getV1();
627  int32_t elementIdx = shuffleIdx;
628  if (elementIdx >= numV1Elems) {
629  vec = adaptor.getV2();
630  elementIdx -= numV1Elems;
631  }
632 
633  newOperand = getElementAtIdx(vec, elementIdx);
634  }
635 
636  // Handle the scalar result corner case.
637  if (newOperands.size() == 1) {
638  rewriter.replaceOp(shuffleOp, newOperands.front());
639  return success();
640  }
641 
642  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
643  shuffleOp, newResultType, newOperands);
644  return success();
645  }
646 };
647 
648 struct VectorInterleaveOpConvert final
649  : public OpConversionPattern<vector::InterleaveOp> {
651 
652  LogicalResult
653  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
654  ConversionPatternRewriter &rewriter) const override {
655  // Check the result vector type.
656  VectorType oldResultType = interleaveOp.getResultVectorType();
657  Type newResultType = getTypeConverter()->convertType(oldResultType);
658  if (!newResultType)
659  return rewriter.notifyMatchFailure(interleaveOp,
660  "unsupported result vector type");
661 
662  // Interleave the indices.
663  VectorType sourceType = interleaveOp.getSourceVectorType();
664  int n = sourceType.getNumElements();
665 
666  // Input vectors of size 1 are converted to scalars by the type converter.
667  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
668  // use `spirv::CompositeConstructOp`.
669  if (n == 1) {
670  Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
671  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
672  interleaveOp, newResultType, newOperands);
673  return success();
674  }
675 
676  auto seq = llvm::seq<int64_t>(2 * n);
677  auto indices = llvm::map_to_vector(
678  seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
679 
680  // Emit a SPIR-V shuffle.
681  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
682  interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
683  rewriter.getI32ArrayAttr(indices));
684 
685  return success();
686  }
687 };
688 
689 struct VectorDeinterleaveOpConvert final
690  : public OpConversionPattern<vector::DeinterleaveOp> {
692 
693  LogicalResult
694  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
695  ConversionPatternRewriter &rewriter) const override {
696 
697  // Check the result vector type.
698  VectorType oldResultType = deinterleaveOp.getResultVectorType();
699  Type newResultType = getTypeConverter()->convertType(oldResultType);
700  if (!newResultType)
701  return rewriter.notifyMatchFailure(deinterleaveOp,
702  "unsupported result vector type");
703 
704  Location loc = deinterleaveOp->getLoc();
705 
706  // Deinterleave the indices.
707  Value sourceVector = adaptor.getSource();
708  VectorType sourceType = deinterleaveOp.getSourceVectorType();
709  int n = sourceType.getNumElements();
710 
711  // Output vectors of size 1 are converted to scalars by the type converter.
712  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
713  // use `spirv::CompositeExtractOp`.
714  if (n == 2) {
715  auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
716  loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
717 
718  auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
719  loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
720 
721  rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
722  return success();
723  }
724 
725  // Indices for `shuffleEven` (result 0).
726  auto seqEven = llvm::seq<int64_t>(n / 2);
727  auto indicesEven =
728  llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
729 
730  // Indices for `shuffleOdd` (result 1).
731  auto seqOdd = llvm::seq<int64_t>(n / 2);
732  auto indicesOdd =
733  llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
734 
735  // Create two SPIR-V shuffles.
736  auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
737  loc, newResultType, sourceVector, sourceVector,
738  rewriter.getI32ArrayAttr(indicesEven));
739 
740  auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
741  loc, newResultType, sourceVector, sourceVector,
742  rewriter.getI32ArrayAttr(indicesOdd));
743 
744  rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
745  return success();
746  }
747 };
748 
749 struct VectorLoadOpConverter final
750  : public OpConversionPattern<vector::LoadOp> {
752 
753  LogicalResult
754  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
755  ConversionPatternRewriter &rewriter) const override {
756  auto memrefType = loadOp.getMemRefType();
757  auto attr =
758  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
759  if (!attr)
760  return rewriter.notifyMatchFailure(
761  loadOp, "expected spirv.storage_class memory space");
762 
763  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
764  auto loc = loadOp.getLoc();
765  Value accessChain =
766  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
767  adaptor.getIndices(), loc, rewriter);
768  if (!accessChain)
769  return rewriter.notifyMatchFailure(
770  loadOp, "failed to get memref element pointer");
771 
772  spirv::StorageClass storageClass = attr.getValue();
773  auto vectorType = loadOp.getVectorType();
774  // Use the converted vector type instead of original (single element vector
775  // would get converted to scalar).
776  auto spirvVectorType = typeConverter.convertType(vectorType);
777  auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
778 
779  // For single element vectors, we don't need to bitcast the access chain to
780  // the original vector type. Both is going to be the same, a pointer
781  // to a scalar.
782  Value castedAccessChain = (vectorType.getNumElements() == 1)
783  ? accessChain
784  : rewriter.create<spirv::BitcastOp>(
785  loc, vectorPtrType, accessChain);
786 
787  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
788  castedAccessChain);
789 
790  return success();
791  }
792 };
793 
794 struct VectorStoreOpConverter final
795  : public OpConversionPattern<vector::StoreOp> {
797 
798  LogicalResult
799  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
800  ConversionPatternRewriter &rewriter) const override {
801  auto memrefType = storeOp.getMemRefType();
802  auto attr =
803  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
804  if (!attr)
805  return rewriter.notifyMatchFailure(
806  storeOp, "expected spirv.storage_class memory space");
807 
808  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
809  auto loc = storeOp.getLoc();
810  Value accessChain =
811  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
812  adaptor.getIndices(), loc, rewriter);
813  if (!accessChain)
814  return rewriter.notifyMatchFailure(
815  storeOp, "failed to get memref element pointer");
816 
817  spirv::StorageClass storageClass = attr.getValue();
818  auto vectorType = storeOp.getVectorType();
819  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
820 
821  // For single element vectors, we don't need to bitcast the access chain to
822  // the original vector type. Both is going to be the same, a pointer
823  // to a scalar.
824  Value castedAccessChain = (vectorType.getNumElements() == 1)
825  ? accessChain
826  : rewriter.create<spirv::BitcastOp>(
827  loc, vectorPtrType, accessChain);
828 
829  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
830  adaptor.getValueToStore());
831 
832  return success();
833  }
834 };
835 
836 struct VectorReductionToIntDotProd final
837  : OpRewritePattern<vector::ReductionOp> {
839 
840  LogicalResult matchAndRewrite(vector::ReductionOp op,
841  PatternRewriter &rewriter) const override {
842  if (op.getKind() != vector::CombiningKind::ADD)
843  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
844 
845  auto resultType = dyn_cast<IntegerType>(op.getType());
846  if (!resultType)
847  return rewriter.notifyMatchFailure(op, "result is not an integer");
848 
849  int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
850  if (!llvm::is_contained({32, 64}, resultBitwidth))
851  return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
852 
853  VectorType inVecTy = op.getSourceVectorType();
854  if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
855  inVecTy.getShape().size() != 1 || inVecTy.isScalable())
856  return rewriter.notifyMatchFailure(op, "unsupported vector shape");
857 
858  auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
859  if (!mul)
860  return rewriter.notifyMatchFailure(
861  op, "reduction operand is not 'arith.muli'");
862 
863  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
864  spirv::SDotAccSatOp, false>(op, mul, rewriter)))
865  return success();
866 
867  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
868  spirv::UDotAccSatOp, false>(op, mul, rewriter)))
869  return success();
870 
871  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
872  spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
873  return success();
874 
875  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
876  spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
877  return success();
878 
879  return failure();
880  }
881 
882 private:
883  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
884  typename DotAccOp, bool SwapOperands>
885  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
886  PatternRewriter &rewriter) {
887  auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
888  if (!lhs)
889  return failure();
890  Value lhsIn = lhs.getIn();
891  auto lhsInType = cast<VectorType>(lhsIn.getType());
892  if (!lhsInType.getElementType().isInteger(8))
893  return failure();
894 
895  auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
896  if (!rhs)
897  return failure();
898  Value rhsIn = rhs.getIn();
899  auto rhsInType = cast<VectorType>(rhsIn.getType());
900  if (!rhsInType.getElementType().isInteger(8))
901  return failure();
902 
903  if (op.getSourceVectorType().getNumElements() == 3) {
904  IntegerType i8Type = rewriter.getI8Type();
905  auto v4i8Type = VectorType::get({4}, i8Type);
906  Location loc = op.getLoc();
907  Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
908  lhsIn = rewriter.create<spirv::CompositeConstructOp>(
909  loc, v4i8Type, ValueRange{lhsIn, zero});
910  rhsIn = rewriter.create<spirv::CompositeConstructOp>(
911  loc, v4i8Type, ValueRange{rhsIn, zero});
912  }
913 
914  // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
915  // we have to swap operands instead in that case.
916  if (SwapOperands)
917  std::swap(lhsIn, rhsIn);
918 
919  if (Value acc = op.getAcc()) {
920  rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
921  nullptr);
922  } else {
923  rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
924  nullptr);
925  }
926 
927  return success();
928  }
929 };
930 
931 struct VectorReductionToFPDotProd final
932  : OpConversionPattern<vector::ReductionOp> {
934 
935  LogicalResult
936  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
937  ConversionPatternRewriter &rewriter) const override {
938  if (op.getKind() != vector::CombiningKind::ADD)
939  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
940 
941  auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
942  if (!resultType)
943  return rewriter.notifyMatchFailure(op, "result is not a float");
944 
945  Value vec = adaptor.getVector();
946  Value acc = adaptor.getAcc();
947 
948  auto vectorType = dyn_cast<VectorType>(vec.getType());
949  if (!vectorType) {
950  assert(isa<FloatType>(vec.getType()) &&
951  "Expected the vector to be scalarized");
952  if (acc) {
953  rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
954  return success();
955  }
956 
957  rewriter.replaceOp(op, vec);
958  return success();
959  }
960 
961  Location loc = op.getLoc();
962  Value lhs;
963  Value rhs;
964  if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
965  lhs = mul.getLhs();
966  rhs = mul.getRhs();
967  } else {
968  // If the operand is not a mul, use a vector of ones for the dot operand
969  // to just sum up all values.
970  lhs = vec;
971  Attribute oneAttr =
972  rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
973  oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
974  rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
975  }
976  assert(lhs);
977  assert(rhs);
978 
979  Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
980  if (acc)
981  res = rewriter.create<spirv::FAddOp>(loc, acc, res);
982 
983  rewriter.replaceOp(op, res);
984  return success();
985  }
986 };
987 
988 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
990 
991  LogicalResult
992  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
993  ConversionPatternRewriter &rewriter) const override {
994  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
995  Type dstType = typeConverter.convertType(stepOp.getType());
996  if (!dstType)
997  return failure();
998 
999  Location loc = stepOp.getLoc();
1000  int64_t numElements = stepOp.getType().getNumElements();
1001  auto intType =
1002  rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
1003 
1004  // Input vectors of size 1 are converted to scalars by the type converter.
1005  // We just create a constant in this case.
1006  if (numElements == 1) {
1007  Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
1008  rewriter.replaceOp(stepOp, zero);
1009  return success();
1010  }
1011 
1012  SmallVector<Value> source;
1013  source.reserve(numElements);
1014  for (int64_t i = 0; i < numElements; ++i) {
1015  Attribute intAttr = rewriter.getIntegerAttr(intType, i);
1016  Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
1017  source.push_back(constOp);
1018  }
1019  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
1020  source);
1021  return success();
1022  }
1023 };
1024 
1025 } // namespace
1026 #define CL_INT_MAX_MIN_OPS \
1027  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1028 
1029 #define GL_INT_MAX_MIN_OPS \
1030  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1031 
1032 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1033 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1034 
1036  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1037  patterns.add<
1038  VectorBitcastConvert, VectorBroadcastConvert,
1039  VectorExtractElementOpConvert, VectorExtractOpConvert,
1040  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1041  VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1042  VectorInsertElementOpConvert, VectorInsertOpConvert,
1043  VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1044  VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1045  VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1046  VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1047  VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1048  VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1049  VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
1050  VectorStepOpConvert>(typeConverter, patterns.getContext(),
1051  PatternBenefit(1));
1052 
1053  // Make sure that the more specialized dot product pattern has higher benefit
1054  // than the generic one that extracts all elements.
1055  patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
1056  PatternBenefit(2));
1057 }
1058 
1061  patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
1062 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
#define MINUI(lhs, rhs)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:272
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
IntegerType getI8Type()
Definition: Builders.cpp:59
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319