MLIR  20.0.0git
TosaCanonicalizations.cpp
Go to the documentation of this file.
1 //===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // TOSA canonicalization patterns and folders.
11 //
12 //===----------------------------------------------------------------------===//
13 
21 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <functional>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 //===----------------------------------------------------------------------===//
39 // Operator Canonicalizers.
40 //===----------------------------------------------------------------------===//
41 
42 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
44 
45  LogicalResult matchAndRewrite(tosa::ConcatOp op,
46  PatternRewriter &rewriter) const override {
47  if (op.getInput1().size() != 1)
48  return failure();
49  if (op.getInput1().front().getType() != op.getType()) {
50  rewriter
51  .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
52  op.getInput1().front())
53  .getResult();
54  return success();
55  }
56 
57  rewriter.replaceOp(op, op.getInput1().front());
58  return success();
59  }
60 };
61 
62 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
63  MLIRContext *context) {
64  results.add<ConcatOptimization>(context);
65 }
66 
67 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
68  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
69  if (!notOp)
70  return failure();
71  rewriter.modifyOpInPlace(op, [&]() {
72  op.getOperation()->setOperands(
73  {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
74  });
75  return success();
76 }
77 
79  : public OpRewritePattern<tosa::TransposeOp> {
81 
82  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
83  PatternRewriter &rewriter) const override {
84  // Input is also TransposeOp - transpose(transpose(A)).
85  auto innerTranspose =
86  transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
87  if (!innerTranspose)
88  return rewriter.notifyMatchFailure(transposeOp,
89  "input must be transpose operation");
90 
91  SmallVector<int32_t> transposePerms, innerTransposePerms;
92  if (transposeOp.getConstantPerms(transposePerms).failed())
93  return rewriter.notifyMatchFailure(transposeOp,
94  "transpose perms must be constant");
95  if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
96  return rewriter.notifyMatchFailure(
97  transposeOp, "inner transpose perms must be constant");
98  if (transposePerms.size() != innerTransposePerms.size())
99  return rewriter.notifyMatchFailure(
100  transposeOp,
101  "transpose and inner transpose perms sizes must be equal");
102  if (transposePerms.empty())
103  return rewriter.notifyMatchFailure(
104  transposeOp, "transpose perms sizes must be positive");
105 
106  // Consolidate transposes into one transpose.
107  SmallVector<int32_t> perms(transposePerms.size());
108  for (int i = 0, s = transposePerms.size(); i < s; ++i)
109  perms[i] = innerTransposePerms[transposePerms[i]];
110 
111  auto permsTy =
112  RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
113  auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
114  Value permsValue =
115  rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
116 
117  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
118  transposeOp, transposeOp.getResult().getType(),
119  innerTranspose.getInput1(), permsValue);
120 
121  return success();
122  }
123 };
124 
125 // Determines the case when tosa.transpose is a tosa.reshape operation.
126 struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
128 
129  LogicalResult matchAndRewrite(tosa::TransposeOp op,
130  PatternRewriter &rewriter) const override {
131  DenseIntElementsAttr permAttr;
132  if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
133  return rewriter.notifyMatchFailure(op, "Non-constant permutation");
134 
135  if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
136  return rewriter.notifyMatchFailure(
137  op, "Src is from transpose, can compose transposes");
138 
139  Value result = op.getResult();
140  for (Operation *subop : result.getUsers()) {
141  if (dyn_cast_or_null<tosa::TransposeOp>(subop))
142  return rewriter.notifyMatchFailure(
143  op, "Dest is used by transpose, can compose transposes");
144  }
145 
146  auto input = op.getInput1();
147  auto inputTy = llvm::cast<ShapedType>(input.getType());
148  if (!inputTy.hasRank())
149  return rewriter.notifyMatchFailure(op, "Unranked input.");
150 
151  int64_t numDynDims = 0;
152  for (int i = 0; i < inputTy.getRank(); ++i)
153  if (inputTy.isDynamicDim(i))
154  numDynDims++;
155 
156  if (numDynDims > 1)
157  return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
158 
159  SmallVector<int64_t> permValues = llvm::to_vector<6>(
160  llvm::map_range(permAttr.getValues<APInt>(),
161  [](const APInt &val) { return val.getSExtValue(); }));
162 
163  SmallVector<int64_t> nonZeroPerms;
164  nonZeroPerms.reserve(permValues.size());
165  for (auto idx : permValues) {
166  auto sz = inputTy.getDimSize(idx);
167  if (sz != 1)
168  nonZeroPerms.push_back(idx);
169  }
170 
171  for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
172  if (nonZeroPerms[i - 1] > nonZeroPerms[i])
173  return rewriter.notifyMatchFailure(op,
174  "Transpose changes memory layout.");
175 
176  SmallVector<int64_t> newShape;
177  newShape.reserve(inputTy.getRank());
178  for (int i = 0, s = inputTy.getRank(); i < s; ++i)
179  newShape.push_back(inputTy.getDimSize(permValues[i]));
180 
181  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
182  op, op.getType(), op.getInput1(),
183  rewriter.getDenseI64ArrayAttr(newShape));
184  return success();
185  }
186 };
187 
188 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
189  MLIRContext *context) {
191 }
192 
193 struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
195 
196  LogicalResult matchAndRewrite(tosa::PadOp op,
197  PatternRewriter &rewriter) const override {
198  if (op.getPadConst())
199  return failure();
200 
201  auto input = op.getInput1();
202  auto padding = op.getPadding();
203 
204  ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
205  Type elementTy = inputTy.getElementType();
206 
207  Attribute constantAttr;
208  if (llvm::isa<FloatType>(elementTy)) {
209  constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
210  } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
211  constantAttr = rewriter.getIntegerAttr(elementTy, 0);
212  } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
213  auto value = op.getQuantizationInfo()->getInputZp();
214  constantAttr = rewriter.getIntegerAttr(elementTy, value);
215  }
216 
217  if (!constantAttr) {
218  return rewriter.notifyMatchFailure(
219  op,
220  "tosa.pad to linalg lowering encountered an unknown element type");
221  }
222 
223  auto denseAttr = DenseElementsAttr::get(
224  RankedTensorType::get({}, elementTy), constantAttr);
225  auto constantVal = rewriter.create<tosa::ConstOp>(
226  op.getLoc(), denseAttr.getType(), denseAttr);
227 
228  rewriter.replaceOpWithNewOp<tosa::PadOp>(
229  op, op.getType(), ValueRange{input, padding, constantVal},
230  op->getAttrs());
231  return success();
232  }
233 };
234 
235 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
236  MLIRContext *context) {
237  results.add<MaterializePadValue>(context);
238 }
239 
240 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
242 
243  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
244  PatternRewriter &rewriter) const override {
245  Value input = op.getInput();
246  Value output = op.getOutput();
247  ShapedType inputType = llvm::cast<ShapedType>(input.getType());
248  ShapedType outputType = llvm::cast<ShapedType>(output.getType());
249 
250  if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
251  return failure();
252  }
253 
254  // If the output and input shapes are 1x1, then this is a no op.
255  ArrayRef<int64_t> outputShape = outputType.getShape();
256  if (outputShape[1] != 1 || outputShape[2] != 1) {
257  return failure();
258  }
259 
260  ArrayRef<int64_t> inputShape = inputType.getShape();
261  if (inputShape[1] != 1 || inputShape[2] != 1) {
262  return failure();
263  }
264 
265  rewriter.replaceOp(op, input);
266  return success();
267  }
268 };
269 
270 void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
271  MLIRContext *context) {
272  results.add<MaxPool2dIsNoOp>(context);
273 }
274 
275 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
277 
278  LogicalResult matchAndRewrite(tosa::ClampOp op,
279  PatternRewriter &rewriter) const override {
280  Value input = op.getInput();
281  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
282  auto inputElementType = inputType.getElementType();
283 
284  if (!inputType.hasStaticShape()) {
285  return failure();
286  }
287 
288  if (isa<FloatType>(inputElementType)) {
289  // Unlike integer types, floating point types can represent infinity.
290  auto minClamp = op.getMinFp();
291  auto maxClamp = op.getMaxFp();
292  bool isMin = minClamp.isInfinity() && minClamp.isNegative();
293  bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
294 
295  if (isMin && isMax) {
296  rewriter.replaceOp(op, input);
297  return success();
298  }
299  return failure();
300  }
301 
302  if (inputElementType.isUnsignedInteger()) {
303  int64_t minClamp = op.getMinInt();
304  int64_t maxClamp = op.getMaxInt();
305 
306  int64_t intMin =
307  APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
308  .getZExtValue();
309  int64_t intMax =
310  APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
311  .getZExtValue();
312 
313  if (minClamp <= intMin && maxClamp >= intMax) {
314  rewriter.replaceOp(op, input);
315  return success();
316  }
317  return failure();
318  }
319 
320  if (llvm::isa<IntegerType>(inputElementType)) {
321  int64_t minClamp = op.getMinInt();
322  int64_t maxClamp = op.getMaxInt();
323 
324  int64_t intMin =
325  APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
326  .getSExtValue();
327  int64_t intMax =
328  APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
329  .getSExtValue();
330 
331  if (minClamp <= intMin && maxClamp >= intMax) {
332  rewriter.replaceOp(op, input);
333  return success();
334  }
335  return failure();
336  }
337 
338  return failure();
339  }
340 };
341 
342 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
344 
345  LogicalResult matchAndRewrite(tosa::ClampOp op,
346  PatternRewriter &rewriter) const override {
347  Value input = op.getInput();
348 
349  Operation *definingOp = input.getDefiningOp();
350  if (!definingOp)
351  return failure();
352 
353  if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
354  auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
355  auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
356 
357  auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
358  auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
359 
360  rewriter.replaceOpWithNewOp<tosa::ClampOp>(
361  op, op.getType(), clampOp.getInput(),
362  rewriter.getI64IntegerAttr(minInt),
363  rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
364  rewriter.getF32FloatAttr(maxFp));
365  return success();
366  }
367 
368  return failure();
369  }
370 };
371 
372 void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
373  MLIRContext *context) {
374  results.add<ClampIsNoOp>(context);
375  results.add<ClampClampOptimization>(context);
376 }
377 
378 struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
380 
381  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
382  PatternRewriter &rewriter) const override {
383  Value sliceInput = sliceOp.getInput1();
384  auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
385  if (!concatOp)
386  return rewriter.notifyMatchFailure(
387  sliceOp, "slice input must be concat operation");
388 
389  OperandRange inputs = concatOp.getInput1();
390  auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
391  if (!concatType || !concatType.hasStaticShape())
392  return rewriter.notifyMatchFailure(
393  sliceOp, "slice input must be a static ranked tensor");
394  int32_t axis = concatOp.getAxis();
395 
396  llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
397  llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
398 
399  // Validate slice on the concatenated axis. Slicing along this
400  // axis should span only one of the inputs to the concatenate
401  // operation.
402  std::optional<Value> replaceWithSlice;
403  for (auto input : inputs) {
404  auto inputType = dyn_cast<RankedTensorType>(input.getType());
405  if (!inputType || !inputType.hasStaticShape())
406  return rewriter.notifyMatchFailure(
407  sliceOp, "concat input must be a static ranked tensor");
408 
409  if (sliceStart[axis] >= 0 &&
410  (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
411  replaceWithSlice = rewriter
412  .create<tosa::SliceOp>(
413  sliceOp.getLoc(), sliceOp.getType(), input,
414  rewriter.getDenseI64ArrayAttr(sliceStart),
415  rewriter.getDenseI64ArrayAttr(sliceSize))
416  .getResult();
417  break;
418  }
419  sliceStart[axis] -= inputType.getDimSize(axis);
420  }
421 
422  if (!replaceWithSlice)
423  return rewriter.notifyMatchFailure(
424  sliceOp, "corresponding concat input not found for slice");
425 
426  rewriter.replaceOp(sliceOp, replaceWithSlice.value());
427  return success();
428  }
429 };
430 
431 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
432  MLIRContext *context) {
433  results.add<ConcatSliceOptimization>(context);
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Operator Folders.
438 //===----------------------------------------------------------------------===//
439 
440 template <typename IntFolder, typename FloatFolder>
442  RankedTensorType returnTy) {
443  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
444  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
445  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
446  if (lETy != rETy)
447  return {};
448 
449  if (llvm::isa<IntegerType>(lETy)) {
450  APInt l = lhs.getSplatValue<APInt>();
451  APInt r = rhs.getSplatValue<APInt>();
452  auto result = IntFolder()(l, r);
453  return DenseElementsAttr::get(returnTy, result);
454  }
455 
456  if (llvm::isa<FloatType>(lETy)) {
457  APFloat l = lhs.getSplatValue<APFloat>();
458  APFloat r = rhs.getSplatValue<APFloat>();
459  auto result = FloatFolder()(l, r);
460  return DenseElementsAttr::get(returnTy, result);
461  }
462  }
463 
464  return {};
465 }
466 
467 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
468  if (llvm::isa<FloatType>(elemType))
469  return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
470  if (llvm::isa<IntegerType>(elemType))
471  return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
472  return false;
473 }
474 
475 static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
476  if (llvm::isa<FloatType>(elemType))
477  return val && val.isSplat() &&
478  val.getSplatValue<APFloat>().isExactlyValue(1.0);
479  if (llvm::isa<IntegerType>(elemType)) {
480  const int64_t shifted = 1LL << shift;
481  return val && val.isSplat() &&
482  val.getSplatValue<APInt>().getSExtValue() == shifted;
483  }
484  return false;
485 }
486 
487 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
488  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
489  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
490  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
491  if (!lhsTy || !rhsTy || !resultTy)
492  return {};
493 
494  // Cannot create an ElementsAttr from non-int/float/index types
495  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
496  !rhsTy.getElementType().isIntOrIndexOrFloat())
497  return {};
498 
499  auto resultETy = resultTy.getElementType();
500  auto lhsAttr =
501  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
502  auto rhsAttr =
503  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
504 
505  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
506  return getInput1();
507  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
508  return getInput2();
509 
510  if (!lhsAttr || !rhsAttr)
511  return {};
512 
513  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
514  resultTy);
515 }
516 
517 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
518  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
519  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
520  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
521  !outputTy.hasStaticShape())
522  return {};
523 
524  if (inputTy.getDimSize(getAxis()) == 1)
525  return DenseElementsAttr::get(outputTy, 0);
526 
527  return {};
528 }
529 
530 OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
531  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
532  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
533  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
534  if (!lhsTy || !rhsTy || !resultTy)
535  return {};
536  if (lhsTy != rhsTy)
537  return {};
538 
539  // IntDivOp inputs must be integer type, no need to check for quantized type
540  auto resultETy = resultTy.getElementType();
541  auto lhsAttr =
542  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
543  auto rhsAttr =
544  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
545  if (lhsAttr && lhsAttr.isSplat()) {
546  if (llvm::isa<IntegerType>(resultETy) &&
547  lhsAttr.getSplatValue<APInt>().isZero())
548  return lhsAttr;
549  }
550 
551  if (rhsAttr && rhsAttr.isSplat()) {
552  if (llvm::isa<IntegerType>(resultETy) &&
553  rhsAttr.getSplatValue<APInt>().isOne())
554  return getInput1();
555  }
556 
557  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
558  if (llvm::isa<IntegerType>(resultETy)) {
559  APInt l = lhsAttr.getSplatValue<APInt>();
560  APInt r = rhsAttr.getSplatValue<APInt>();
561  APInt result = l.sdiv(r);
562  return DenseElementsAttr::get(resultTy, result);
563  }
564  }
565 
566  return {};
567 }
568 
569 namespace {
570 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
571  RankedTensorType ty, int32_t shift) {
572  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
573  if (llvm::isa<IntegerType>(ty.getElementType())) {
574  APInt l = lhs.getSplatValue<APInt>();
575  APInt r = rhs.getSplatValue<APInt>();
576 
577  if (shift == 0) {
578  return DenseElementsAttr::get(ty, l * r);
579  }
580 
581  auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
582  l = l.sext(bitwidth * 2);
583  r = r.sext(bitwidth * 2);
584  auto result = l * r;
585  result.lshrInPlace(shift);
586  result = result.trunc(bitwidth);
587  return DenseElementsAttr::get(ty, result);
588  }
589 
590  if (llvm::isa<FloatType>(ty.getElementType())) {
591  APFloat l = lhs.getSplatValue<APFloat>();
592  APFloat r = rhs.getSplatValue<APFloat>();
593  APFloat result = l * r;
594  return DenseElementsAttr::get(ty, result);
595  }
596  }
597 
598  return {};
599 }
600 } // namespace
601 
602 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
603  auto lhs = getInput1();
604  auto rhs = getInput2();
605  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
606  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
607  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
608  if (!lhsTy || !rhsTy || !resultTy)
609  return {};
610 
611  auto resultETy = resultTy.getElementType();
612  auto lhsAttr =
613  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
614  auto rhsAttr =
615  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
616 
617  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
618 
619  if (rhsTy == resultTy) {
620  if (isSplatZero(resultETy, lhsAttr))
621  return lhsAttr.resizeSplat(resultTy);
622  if (isSplatOne(resultETy, lhsAttr, shift))
623  return rhs;
624  }
625  if (lhsTy == resultTy) {
626  if (isSplatZero(resultETy, rhsAttr))
627  return rhsAttr.resizeSplat(resultTy);
628  if (isSplatOne(resultETy, rhsAttr, shift))
629  return lhs;
630  }
631 
632  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
633 }
634 
635 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
636  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
637  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
638  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
639  if (!lhsTy || !rhsTy || !resultTy)
640  return {};
641 
642  // Cannot create an ElementsAttr from non-int/float/index types
643  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
644  !rhsTy.getElementType().isIntOrIndexOrFloat())
645  return {};
646 
647  auto resultETy = resultTy.getElementType();
648  auto lhsAttr =
649  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
650  auto rhsAttr =
651  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
652 
653  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
654  return getInput1();
655 
656  if (!lhsAttr || !rhsAttr)
657  return {};
658 
659  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
660  resultTy);
661 }
662 
663 namespace {
664 template <typename Cmp>
665 struct ComparisonFold {
666  ComparisonFold() = default;
667  APInt operator()(const APInt &l, const APInt &r) {
668  return APInt(1, Cmp()(l, r));
669  }
670 
671  APInt operator()(const APFloat &l, const APFloat &r) {
672  return APInt(1, Cmp()(l, r));
673  }
674 };
675 
676 struct APIntFoldGreater {
677  APIntFoldGreater() = default;
678  APInt operator()(const APInt &l, const APInt &r) {
679  return APInt(1, l.sgt(r));
680  }
681 };
682 
683 struct APIntFoldGreaterEqual {
684  APIntFoldGreaterEqual() = default;
685  APInt operator()(const APInt &l, const APInt &r) {
686  return APInt(1, l.sge(r));
687  }
688 };
689 } // namespace
690 
691 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
692  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
693  auto lhsAttr =
694  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
695  auto rhsAttr =
696  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
697 
698  if (!lhsAttr || !rhsAttr)
699  return {};
700 
701  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
702  lhsAttr, rhsAttr, resultTy);
703 }
704 
705 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
706  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
707  auto lhsAttr =
708  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
709  auto rhsAttr =
710  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
711 
712  if (!lhsAttr || !rhsAttr)
713  return {};
714 
715  return binaryFolder<APIntFoldGreaterEqual,
716  ComparisonFold<std::greater_equal<APFloat>>>(
717  lhsAttr, rhsAttr, resultTy);
718 }
719 
720 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
721  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
722  auto lhsAttr =
723  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
724  auto rhsAttr =
725  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
726  Value lhs = getInput1();
727  Value rhs = getInput2();
728  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
729 
730  // If we are comparing an integer value to itself it is always true. We can
731  // not do this with float due to float values.
732  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
733  resultTy.hasStaticShape() && lhs == rhs) {
734  return DenseElementsAttr::get(resultTy, true);
735  }
736 
737  if (!lhsAttr || !rhsAttr)
738  return {};
739 
740  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
741  ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
742  resultTy);
743 }
744 
745 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
746  if (getInput().getType() == getType())
747  return getInput();
748 
749  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
750  if (!operand)
751  return {};
752 
753  auto inTy = llvm::cast<ShapedType>(getInput().getType());
754  auto outTy = llvm::cast<ShapedType>(getType());
755  auto inETy = inTy.getElementType();
756  auto outETy = outTy.getElementType();
757 
758  if (operand.isSplat()) {
759  if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
760  bool overflow;
761  auto splatVal = operand.getSplatValue<APFloat>();
762  auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
763  splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
764  &overflow);
765  return SplatElementsAttr::get(outTy, splatVal);
766  }
767 
768  if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
769  auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
770  APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
771  splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
772  llvm::RoundingMode::NearestTiesToEven);
773  return SplatElementsAttr::get(outTy, splatVal);
774  }
775 
776  if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
777  auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
778  auto intVal = APSInt(
779  llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
780  auto floatVal = operand.getSplatValue<APFloat>();
781  bool exact;
782  floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
783  &exact);
784  return SplatElementsAttr::get(outTy, intVal);
785  }
786 
787  if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
788  auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
789  bool trunc =
790  inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
791  auto intVal = operand.getSplatValue<APInt>();
792  auto bitwidth = outETy.getIntOrFloatBitWidth();
793 
794  if (trunc) {
795  intVal = intVal.trunc(bitwidth);
796  } else if (unsignIn) {
797  intVal = intVal.zext(bitwidth);
798  } else {
799  intVal = intVal.sext(bitwidth);
800  }
801 
802  return SplatElementsAttr::get(outTy, intVal);
803  }
804  }
805 
806  return {};
807 }
808 
809 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
810 
811 #define REDUCE_FOLDER(OP) \
812  OpFoldResult OP::fold(FoldAdaptor adaptor) { \
813  ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
814  if (!inputTy.hasRank()) \
815  return {}; \
816  if (inputTy != getType()) \
817  return {}; \
818  if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
819  return getInput(); \
820  return {}; \
821  }
822 
823 REDUCE_FOLDER(ReduceAllOp)
824 REDUCE_FOLDER(ReduceAnyOp)
825 REDUCE_FOLDER(ReduceMaxOp)
826 REDUCE_FOLDER(ReduceMinOp)
827 REDUCE_FOLDER(ReduceProdOp)
828 REDUCE_FOLDER(ReduceSumOp)
829 #undef REDUCE_FOLDER
830 
831 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
832  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
833  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
834 
835  if (!inputTy || !outputTy)
836  return {};
837 
838  // Fold when the input and output types are the same. This is only safe when
839  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
840  // there may still be a productive reshape.
841  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
842  return getInput1();
843 
844  // reshape(reshape(x)) -> reshape(x)
845  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
846  getInput1().getDefiningOp())) {
847  getInput1Mutable().assign(reshapeOp.getInput1());
848  return getResult();
849  }
850 
851  // Cannot create an ElementsAttr from non-int/float/index types
852  if (!inputTy.getElementType().isIntOrIndexOrFloat())
853  return {};
854 
855  // reshape(const(x)) -> const(reshape-attr(x))
856  if (auto operand =
857  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
858  // Constants must have static shape.
859  if (!outputTy.hasStaticShape())
860  return {};
861 
862  // Okay to duplicate splat constants.
863  if (operand.isSplat())
864  return SplatElementsAttr::get(outputTy,
865  operand.getSplatValue<Attribute>());
866 
867  // Don't duplicate other constants.
868  if (!getInput1().hasOneUse())
869  return {};
870 
871  return operand.reshape(
872  llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
873  }
874 
875  return {};
876 }
877 
878 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
879  // If the pad is all zeros we can fold this operation away.
880  if (adaptor.getPadding() && getInput1().getType() == getType()) {
881  auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
882  if (densePad && densePad.isSplat() &&
883  densePad.getSplatValue<APInt>().isZero()) {
884  return getInput1();
885  }
886  }
887 
888  return {};
889 }
890 
891 // Fold away cases where a tosa.resize operation returns a copy
892 // of the input image.
893 OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
894  ArrayRef<int64_t> offset = getOffset();
895  ArrayRef<int64_t> border = getBorder();
896  ArrayRef<int64_t> scale = getScale();
897 
898  // Check unit scaling.
899  if (scale[0] != scale[1] || scale[2] != scale[3]) {
900  return {};
901  }
902 
903  // There should be no offset.
904  if (offset[0] != 0 || offset[1] != 0) {
905  return {};
906  }
907 
908  // There should be no border.
909  if (border[0] != 0 || border[1] != 0) {
910  return {};
911  }
912 
913  auto input = getInput();
914  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
915  auto resultTy = llvm::cast<RankedTensorType>(getType());
916  if (inputTy != resultTy)
917  return {};
918 
919  return input;
920 }
921 
922 OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
923  auto operand = getInput1();
924  auto operandTy = llvm::cast<ShapedType>(operand.getType());
925  auto axis = getAxis();
926  auto operandAttr =
927  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
928  if (operandAttr)
929  return operandAttr;
930 
931  // If the dim-length is 1, tosa.reverse is a no-op.
932  if (operandTy.hasRank() &&
933  (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
934  return operand;
935 
936  return {};
937 }
938 
939 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
940  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
941  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
942 
943  if (!inputTy || !outputTy)
944  return {};
945 
946  if (inputTy == outputTy && inputTy.hasStaticShape())
947  return getInput1();
948 
949  if (!adaptor.getInput1())
950  return {};
951 
952  // Cannot create an ElementsAttr from non-int/float/index types
953  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
954  !outputTy.getElementType().isIntOrIndexOrFloat())
955  return {};
956 
957  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
958  if (operand.isSplat() && outputTy.hasStaticShape()) {
959  return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
960  }
961 
962  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
963  outputTy.getNumElements() == 1) {
964  llvm::SmallVector<uint64_t> indices(getStart());
965  auto value = operand.getValues<Attribute>()[indices];
966  return SplatElementsAttr::get(outputTy, value);
967  }
968 
969  return {};
970 }
971 
972 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
973  if (getOnTrue() == getOnFalse())
974  return getOnTrue();
975 
976  auto predicate =
977  llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
978  if (!predicate)
979  return {};
980 
981  if (!predicate.isSplat())
982  return {};
983  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
984  : getOnFalse();
985 }
986 
987 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
988  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
989  if (allOnes && getInput1().getType() == getType())
990  return getInput1();
991  return {};
992 }
993 
994 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
995  auto resultTy = llvm::cast<ShapedType>(getType());
996 
997  // Transposing splat values just means reshaping.
998  if (auto input =
999  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1000  if (input.isSplat() && resultTy.hasStaticShape() &&
1001  input.getType().getElementType() == resultTy.getElementType())
1002  return input.reshape(resultTy);
1003  }
1004 
1005  // Transpose does not change the input type.
1006  if (getInput1().getType() != getType())
1007  return {};
1008 
1009  // Transpose is not the identity transpose.
1010  SmallVector<int32_t> perms;
1011  if (getConstantPerms(perms).failed())
1012  return {};
1013 
1014  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1015  return {};
1016 
1017  return getInput1();
1018 }
1019 
1020 OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
1021  auto input = getInput1();
1022  // Element-wise log(exp(x)) = x
1023  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
1024  return op.getInput1();
1025  }
1026 
1027  return {};
1028 }
1029 
1030 OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
1031  auto input = getInput1();
1032  // Element-wise exp(log(x)) = x
1033  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
1034  return op.getInput1();
1035  }
1036 
1037  return {};
1038 }
1039 
1040 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1041  auto input = getInput1();
1042  // Element-wise negate(negate(x)) = x
1043  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
1044  return op.getInput1();
1045  }
1046 
1047  return {};
1048 }
1049 
1050 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1051  auto input = getInput1();
1052  // Element-wise abs(abs(x)) = abs(x)
1053  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1054  return input;
1055  }
1056 
1057  return {};
1058 }
1059 
1060 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1061  // Fold consecutive concats on the same axis into a single op.
1062  // Keep track of the operands so we are able to construct a new concat
1063  // later. Conservatively assume that we double the number of operands when
1064  // folding
1065  SmallVector<Value, 8> concatOperands;
1066  concatOperands.reserve(2 * getNumOperands());
1067 
1068  // Find all operands that are foldable concats
1069  bool foundFoldableConcat = false;
1070  for (Value operand : getOperands()) {
1071  concatOperands.emplace_back(operand);
1072 
1073  auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
1074  if (!producer)
1075  continue;
1076 
1077  // Not foldable if axes are not the same
1078  if (getAxis() != producer.getAxis())
1079  continue;
1080 
1081  // Replace the original operand with all incoming operands
1082  foundFoldableConcat = true;
1083  concatOperands.pop_back();
1084  llvm::append_range(concatOperands, producer->getOperands());
1085  }
1086 
1087  if (!foundFoldableConcat)
1088  return {};
1089 
1090  getOperation()->setOperands(concatOperands);
1091  return getResult();
1092 }
1093 
1094 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1095  auto input = adaptor.getInput1();
1096 
1097  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1098  // Fold splat inputs only.
1099  if (!inputAttr || !inputAttr.isSplat())
1100  return {};
1101 
1102  auto shapeType = llvm::cast<ShapedType>(getType());
1103  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1104  auto floatVal = inputAttr.getSplatValue<APFloat>();
1105  return DenseElementsAttr::get(shapeType,
1106  ReciprocalOp::calcOneElement(floatVal));
1107  }
1108 
1109  return {};
1110 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy)
#define REDUCE_FOLDER(OP)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:286
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::PadOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362