MLIR  20.0.0git
TosaCanonicalizations.cpp
Go to the documentation of this file.
1 //===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // TOSA canonicalization patterns and folders.
11 //
12 //===----------------------------------------------------------------------===//
13 
21 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <functional>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 //===----------------------------------------------------------------------===//
39 // Operator Canonicalizers.
40 //===----------------------------------------------------------------------===//
41 
42 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
44 
45  LogicalResult matchAndRewrite(tosa::ConcatOp op,
46  PatternRewriter &rewriter) const override {
47  if (op.getInput1().size() != 1)
48  return failure();
49  if (op.getInput1().front().getType() != op.getType()) {
50  rewriter
51  .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
52  op.getInput1().front())
53  .getResult();
54  return success();
55  }
56 
57  rewriter.replaceOp(op, op.getInput1().front());
58  return success();
59  }
60 };
61 
62 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
63  MLIRContext *context) {
64  results.add<ConcatOptimization>(context);
65 }
66 
67 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
68  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
69  if (!notOp)
70  return failure();
71  rewriter.modifyOpInPlace(op, [&]() {
72  op.getOperation()->setOperands(
73  {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
74  });
75  return success();
76 }
77 
79  : public OpRewritePattern<tosa::TransposeOp> {
81 
82  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
83  PatternRewriter &rewriter) const override {
84  // Input is also TransposeOp - transpose(transpose(A)).
85  auto innerTranspose =
86  transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
87  if (!innerTranspose)
88  return rewriter.notifyMatchFailure(transposeOp,
89  "input must be transpose operation");
90 
91  SmallVector<int64_t> transposePerms, innerTransposePerms;
92  if (transposeOp.getConstantPerms(transposePerms).failed())
93  return rewriter.notifyMatchFailure(transposeOp,
94  "transpose perms must be constant");
95  if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
96  return rewriter.notifyMatchFailure(
97  transposeOp, "inner transpose perms must be constant");
98  if (transposePerms.size() != innerTransposePerms.size())
99  return rewriter.notifyMatchFailure(
100  transposeOp,
101  "transpose and inner transpose perms sizes must be equal");
102  if (transposePerms.empty())
103  return rewriter.notifyMatchFailure(
104  transposeOp, "transpose perms sizes must be positive");
105 
106  // Consolidate transposes into one transpose.
107  SmallVector<int32_t> perms(transposePerms.size());
108  for (int i = 0, s = transposePerms.size(); i < s; ++i)
109  perms[i] = innerTransposePerms[transposePerms[i]];
110 
111  auto permsTy =
112  RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
113  auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
114  Value permsValue =
115  rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
116 
117  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
118  transposeOp, transposeOp.getResult().getType(),
119  innerTranspose.getInput1(), permsValue);
120 
121  return success();
122  }
123 };
124 
125 // Determines the case when tosa.transpose is a tosa.reshape operation.
126 struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
128 
129  LogicalResult matchAndRewrite(tosa::TransposeOp op,
130  PatternRewriter &rewriter) const override {
131  DenseIntElementsAttr permAttr;
132  if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
133  return rewriter.notifyMatchFailure(op, "Non-constant permutation");
134 
135  if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
136  return rewriter.notifyMatchFailure(
137  op, "Src is from transpose, can compose transposes");
138 
139  Value result = op.getResult();
140  for (Operation *subop : result.getUsers()) {
141  if (dyn_cast_or_null<tosa::TransposeOp>(subop))
142  return rewriter.notifyMatchFailure(
143  op, "Dest is used by transpose, can compose transposes");
144  }
145 
146  auto input = op.getInput1();
147  auto inputTy = llvm::cast<ShapedType>(input.getType());
148  if (!inputTy.hasRank())
149  return rewriter.notifyMatchFailure(op, "Unranked input.");
150 
151  int64_t numDynDims = 0;
152  for (int i = 0; i < inputTy.getRank(); ++i)
153  if (inputTy.isDynamicDim(i))
154  numDynDims++;
155 
156  if (numDynDims > 1)
157  return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
158 
159  SmallVector<int64_t> permValues = llvm::to_vector<6>(
160  llvm::map_range(permAttr.getValues<APInt>(),
161  [](const APInt &val) { return val.getSExtValue(); }));
162 
163  SmallVector<int64_t> nonZeroPerms;
164  nonZeroPerms.reserve(permValues.size());
165  for (auto idx : permValues) {
166  auto sz = inputTy.getDimSize(idx);
167  if (sz != 1)
168  nonZeroPerms.push_back(idx);
169  }
170 
171  for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
172  if (nonZeroPerms[i - 1] > nonZeroPerms[i])
173  return rewriter.notifyMatchFailure(op,
174  "Transpose changes memory layout.");
175 
176  SmallVector<int64_t> newShape;
177  newShape.reserve(inputTy.getRank());
178  for (int i = 0, s = inputTy.getRank(); i < s; ++i)
179  newShape.push_back(inputTy.getDimSize(permValues[i]));
180 
181  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
182  op, op.getType(), op.getInput1(),
183  rewriter.getDenseI64ArrayAttr(newShape));
184  return success();
185  }
186 };
187 
188 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
189  MLIRContext *context) {
191 }
192 
193 struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
195 
196  LogicalResult matchAndRewrite(tosa::PadOp op,
197  PatternRewriter &rewriter) const override {
198  if (op.getPadConst())
199  return failure();
200 
201  auto input = op.getInput1();
202  auto padding = op.getPadding();
203 
204  ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
205  Type elementTy = inputTy.getElementType();
206 
207  Attribute constantAttr;
208  if (llvm::isa<FloatType>(elementTy)) {
209  constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
210  } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
211  constantAttr = rewriter.getIntegerAttr(elementTy, 0);
212  } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
213  auto value = op.getQuantizationInfo()->getInputZp();
214  constantAttr = rewriter.getIntegerAttr(elementTy, value);
215  }
216 
217  if (!constantAttr) {
218  return rewriter.notifyMatchFailure(
219  op,
220  "tosa.pad to linalg lowering encountered an unknown element type");
221  }
222 
223  auto denseAttr = DenseElementsAttr::get(
224  RankedTensorType::get({}, elementTy), constantAttr);
225  auto constantVal = rewriter.create<tosa::ConstOp>(
226  op.getLoc(), denseAttr.getType(), denseAttr);
227 
228  rewriter.replaceOpWithNewOp<tosa::PadOp>(
229  op, op.getType(), ValueRange{input, padding, constantVal},
230  op->getAttrs());
231  return success();
232  }
233 };
234 
235 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
236  MLIRContext *context) {
237  results.add<MaterializePadValue>(context);
238 }
239 
240 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
242 
243  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
244  PatternRewriter &rewriter) const override {
245  Value input = op.getInput();
246  Value output = op.getOutput();
247  ShapedType inputType = llvm::cast<ShapedType>(input.getType());
248  ShapedType outputType = llvm::cast<ShapedType>(output.getType());
249 
250  if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
251  return failure();
252  }
253 
254  // If the output and input shapes are 1x1, then this is a no op.
255  ArrayRef<int64_t> outputShape = outputType.getShape();
256  if (outputShape[1] != 1 || outputShape[2] != 1) {
257  return failure();
258  }
259 
260  ArrayRef<int64_t> inputShape = inputType.getShape();
261  if (inputShape[1] != 1 || inputShape[2] != 1) {
262  return failure();
263  }
264 
265  rewriter.replaceOp(op, input);
266  return success();
267  }
268 };
269 
270 void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
271  MLIRContext *context) {
272  results.add<MaxPool2dIsNoOp>(context);
273 }
274 
275 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
277 
278  LogicalResult matchAndRewrite(tosa::ClampOp op,
279  PatternRewriter &rewriter) const override {
280  Value input = op.getInput();
281  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
282  auto inputElementType = inputType.getElementType();
283 
284  if (!inputType.hasStaticShape()) {
285  return failure();
286  }
287 
288  if (isa<FloatType>(inputElementType)) {
289  // Unlike integer types, floating point types can represent infinity.
290  auto minClamp = op.getMinFp();
291  auto maxClamp = op.getMaxFp();
292  bool isMin = minClamp.isInfinity() && minClamp.isNegative();
293  bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
294 
295  if (isMin && isMax) {
296  rewriter.replaceOp(op, input);
297  return success();
298  }
299  return failure();
300  }
301 
302  if (inputElementType.isUnsignedInteger()) {
303  int64_t minClamp = op.getMinInt();
304  int64_t maxClamp = op.getMaxInt();
305 
306  int64_t intMin =
307  APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
308  .getZExtValue();
309  int64_t intMax =
310  APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
311  .getZExtValue();
312 
313  if (minClamp <= intMin && maxClamp >= intMax) {
314  rewriter.replaceOp(op, input);
315  return success();
316  }
317  return failure();
318  }
319 
320  if (llvm::isa<IntegerType>(inputElementType)) {
321  int64_t minClamp = op.getMinInt();
322  int64_t maxClamp = op.getMaxInt();
323 
324  int64_t intMin =
325  APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
326  .getSExtValue();
327  int64_t intMax =
328  APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
329  .getSExtValue();
330 
331  if (minClamp <= intMin && maxClamp >= intMax) {
332  rewriter.replaceOp(op, input);
333  return success();
334  }
335  return failure();
336  }
337 
338  return failure();
339  }
340 };
341 
342 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
344 
345  LogicalResult matchAndRewrite(tosa::ClampOp op,
346  PatternRewriter &rewriter) const override {
347  Value input = op.getInput();
348 
349  Operation *definingOp = input.getDefiningOp();
350  if (!definingOp)
351  return failure();
352 
353  if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
354  auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
355  auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
356 
357  auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
358  auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
359 
360  rewriter.replaceOpWithNewOp<tosa::ClampOp>(
361  op, op.getType(), clampOp.getInput(),
362  rewriter.getI64IntegerAttr(minInt),
363  rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
364  rewriter.getF32FloatAttr(maxFp));
365  return success();
366  }
367 
368  return failure();
369  }
370 };
371 
372 void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
373  MLIRContext *context) {
374  results.add<ClampIsNoOp>(context);
375  results.add<ClampClampOptimization>(context);
376 }
377 
378 struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
380 
381  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
382  PatternRewriter &rewriter) const override {
383  Value sliceInput = sliceOp.getInput();
384  auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
385  if (!concatOp)
386  return rewriter.notifyMatchFailure(
387  sliceOp, "slice input must be concat operation");
388 
389  OperandRange inputs = concatOp.getInput1();
390  auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
391  if (!concatType || !concatType.hasStaticShape())
392  return rewriter.notifyMatchFailure(
393  sliceOp, "slice input must be a static ranked tensor");
394  int32_t axis = concatOp.getAxis();
395 
396  llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
397  llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
398 
399  // Validate slice on the concatenated axis. Slicing along this
400  // axis should span only one of the inputs to the concatenate
401  // operation.
402  std::optional<Value> replaceWithSlice;
403  for (auto input : inputs) {
404  auto inputType = dyn_cast<RankedTensorType>(input.getType());
405  if (!inputType || !inputType.hasStaticShape())
406  return rewriter.notifyMatchFailure(
407  sliceOp, "concat input must be a static ranked tensor");
408 
409  if (sliceStart[axis] >= 0 &&
410  (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
411  replaceWithSlice = rewriter
412  .create<tosa::SliceOp>(
413  sliceOp.getLoc(), sliceOp.getType(), input,
414  rewriter.getDenseI64ArrayAttr(sliceStart),
415  rewriter.getDenseI64ArrayAttr(sliceSize))
416  .getResult();
417  break;
418  }
419  sliceStart[axis] -= inputType.getDimSize(axis);
420  }
421 
422  if (!replaceWithSlice)
423  return rewriter.notifyMatchFailure(
424  sliceOp, "corresponding concat input not found for slice");
425 
426  rewriter.replaceOp(sliceOp, replaceWithSlice.value());
427  return success();
428  }
429 };
430 
431 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
432  MLIRContext *context) {
433  results.add<ConcatSliceOptimization>(context);
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Operator Folders.
438 //===----------------------------------------------------------------------===//
439 
440 template <typename IntFolder, typename FloatFolder>
442  RankedTensorType returnTy) {
443  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
444  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
445  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
446  if (lETy != rETy)
447  return {};
448 
449  if (llvm::isa<IntegerType>(lETy)) {
450  APInt l = lhs.getSplatValue<APInt>();
451  APInt r = rhs.getSplatValue<APInt>();
452  auto result = IntFolder()(l, r);
453  return DenseElementsAttr::get(returnTy, result);
454  }
455 
456  if (llvm::isa<FloatType>(lETy)) {
457  APFloat l = lhs.getSplatValue<APFloat>();
458  APFloat r = rhs.getSplatValue<APFloat>();
459  auto result = FloatFolder()(l, r);
460  return DenseElementsAttr::get(returnTy, result);
461  }
462  }
463 
464  return {};
465 }
466 
467 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
468  if (llvm::isa<FloatType>(elemType))
469  return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
470  if (llvm::isa<IntegerType>(elemType))
471  return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
472  return false;
473 }
474 
475 static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
476  if (llvm::isa<FloatType>(elemType))
477  return val && val.isSplat() &&
478  val.getSplatValue<APFloat>().isExactlyValue(1.0);
479  if (llvm::isa<IntegerType>(elemType)) {
480  const int64_t shifted = 1LL << shift;
481  return val && val.isSplat() &&
482  val.getSplatValue<APInt>().getSExtValue() == shifted;
483  }
484  return false;
485 }
486 
487 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
488  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
489  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
490  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
491  if (!lhsTy || !rhsTy || !resultTy)
492  return {};
493 
494  // Cannot create an ElementsAttr from non-int/float/index types
495  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
496  !rhsTy.getElementType().isIntOrIndexOrFloat())
497  return {};
498 
499  auto resultETy = resultTy.getElementType();
500  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
501  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
502 
503  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
504  return getInput1();
505  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
506  return getInput2();
507 
508  if (!lhsAttr || !rhsAttr)
509  return {};
510 
511  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
512  resultTy);
513 }
514 
515 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
516  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
517  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
518  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
519  !outputTy.hasStaticShape())
520  return {};
521 
522  if (inputTy.getDimSize(getAxis()) == 1)
523  return DenseElementsAttr::get(outputTy, 0);
524 
525  return {};
526 }
527 
528 OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
529  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
530  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
531  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
532  if (!lhsTy || !rhsTy || !resultTy)
533  return {};
534  if (lhsTy != rhsTy)
535  return {};
536 
537  // IntDivOp inputs must be integer type, no need to check for quantized type
538  auto resultETy = resultTy.getElementType();
539  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
540  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
541  if (lhsAttr && lhsAttr.isSplat()) {
542  if (llvm::isa<IntegerType>(resultETy) &&
543  lhsAttr.getSplatValue<APInt>().isZero())
544  return lhsAttr;
545  }
546 
547  if (rhsAttr && rhsAttr.isSplat()) {
548  if (llvm::isa<IntegerType>(resultETy) &&
549  rhsAttr.getSplatValue<APInt>().isOne())
550  return getInput1();
551  }
552 
553  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
554  if (llvm::isa<IntegerType>(resultETy)) {
555  APInt l = lhsAttr.getSplatValue<APInt>();
556  APInt r = rhsAttr.getSplatValue<APInt>();
557  APInt result = l.sdiv(r);
558  return DenseElementsAttr::get(resultTy, result);
559  }
560  }
561 
562  return {};
563 }
564 
565 namespace {
566 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
567  RankedTensorType ty, int32_t shift) {
568  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
569  if (llvm::isa<IntegerType>(ty.getElementType())) {
570  APInt l = lhs.getSplatValue<APInt>();
571  APInt r = rhs.getSplatValue<APInt>();
572 
573  if (shift == 0) {
574  return DenseElementsAttr::get(ty, l * r);
575  }
576 
577  auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
578  l = l.sext(bitwidth * 2);
579  r = r.sext(bitwidth * 2);
580  auto result = l * r;
581  result.lshrInPlace(shift);
582  result = result.trunc(bitwidth);
583  return DenseElementsAttr::get(ty, result);
584  }
585 
586  if (llvm::isa<FloatType>(ty.getElementType())) {
587  APFloat l = lhs.getSplatValue<APFloat>();
588  APFloat r = rhs.getSplatValue<APFloat>();
589  APFloat result = l * r;
590  return DenseElementsAttr::get(ty, result);
591  }
592  }
593 
594  return {};
595 }
596 } // namespace
597 
598 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
599  auto lhs = getInput1();
600  auto rhs = getInput2();
601  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
602  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
603  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
604  if (!lhsTy || !rhsTy || !resultTy)
605  return {};
606 
607  auto resultETy = resultTy.getElementType();
608  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
609  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
610 
611  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
612  if (rhsTy == resultTy) {
613  if (isSplatZero(resultETy, lhsAttr))
614  return lhsAttr.resizeSplat(resultTy);
615  if (isSplatOne(resultETy, lhsAttr, shift))
616  return rhs;
617  }
618  if (lhsTy == resultTy) {
619  if (isSplatZero(resultETy, rhsAttr))
620  return rhsAttr.resizeSplat(resultTy);
621  if (isSplatOne(resultETy, rhsAttr, shift))
622  return lhs;
623  }
624 
625  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
626 }
627 
628 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
629  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
630  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
631  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
632  if (!lhsTy || !rhsTy || !resultTy)
633  return {};
634 
635  // Cannot create an ElementsAttr from non-int/float/index types
636  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
637  !rhsTy.getElementType().isIntOrIndexOrFloat())
638  return {};
639 
640  auto resultETy = resultTy.getElementType();
641  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
642  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
643 
644  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
645  return getInput1();
646 
647  if (!lhsAttr || !rhsAttr)
648  return {};
649 
650  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
651  resultTy);
652 }
653 
654 namespace {
655 template <typename Cmp>
656 struct ComparisonFold {
657  ComparisonFold() = default;
658  APInt operator()(const APInt &l, const APInt &r) {
659  return APInt(1, Cmp()(l, r));
660  }
661 
662  APInt operator()(const APFloat &l, const APFloat &r) {
663  return APInt(1, Cmp()(l, r));
664  }
665 };
666 
667 struct APIntFoldGreater {
668  APIntFoldGreater() = default;
669  APInt operator()(const APInt &l, const APInt &r) {
670  return APInt(1, l.sgt(r));
671  }
672 };
673 
674 struct APIntFoldGreaterEqual {
675  APIntFoldGreaterEqual() = default;
676  APInt operator()(const APInt &l, const APInt &r) {
677  return APInt(1, l.sge(r));
678  }
679 };
680 } // namespace
681 
682 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
683  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
684  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
685  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
686 
687  if (!lhsAttr || !rhsAttr)
688  return {};
689 
690  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
691  lhsAttr, rhsAttr, resultTy);
692 }
693 
694 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
695  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
696  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
697  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
698 
699  if (!lhsAttr || !rhsAttr)
700  return {};
701 
702  return binaryFolder<APIntFoldGreaterEqual,
703  ComparisonFold<std::greater_equal<APFloat>>>(
704  lhsAttr, rhsAttr, resultTy);
705 }
706 
707 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
708  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
709  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
710  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
711  Value lhs = getInput1();
712  Value rhs = getInput2();
713  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
714 
715  // If we are comparing an integer value to itself it is always true. We can
716  // not do this with float due to float values.
717  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
718  resultTy.hasStaticShape() && lhs == rhs) {
719  return DenseElementsAttr::get(resultTy, true);
720  }
721 
722  if (!lhsAttr || !rhsAttr)
723  return {};
724 
725  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
726  ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
727  resultTy);
728 }
729 
730 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
731  if (getInput().getType() == getType())
732  return getInput();
733 
734  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
735  if (!operand)
736  return {};
737 
738  auto inTy = llvm::cast<ShapedType>(getInput().getType());
739  auto outTy = llvm::cast<ShapedType>(getType());
740  auto inETy = inTy.getElementType();
741  auto outETy = outTy.getElementType();
742 
743  if (operand.isSplat()) {
744  if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
745  bool overflow;
746  auto splatVal = operand.getSplatValue<APFloat>();
747  auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
748  splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
749  &overflow);
750  return SplatElementsAttr::get(outTy, splatVal);
751  }
752 
753  if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
754  auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
755  APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
756  splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
757  llvm::RoundingMode::NearestTiesToEven);
758  return SplatElementsAttr::get(outTy, splatVal);
759  }
760 
761  if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
762  auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
763  auto intVal = APSInt(
764  llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
765  auto floatVal = operand.getSplatValue<APFloat>();
766  bool exact;
767  floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
768  &exact);
769  return SplatElementsAttr::get(outTy, intVal);
770  }
771 
772  if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
773  auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
774  bool trunc =
775  inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
776  auto intVal = operand.getSplatValue<APInt>();
777  auto bitwidth = outETy.getIntOrFloatBitWidth();
778 
779  if (trunc) {
780  intVal = intVal.trunc(bitwidth);
781  } else if (unsignIn) {
782  intVal = intVal.zext(bitwidth);
783  } else {
784  intVal = intVal.sext(bitwidth);
785  }
786 
787  return SplatElementsAttr::get(outTy, intVal);
788  }
789  }
790 
791  return {};
792 }
793 
794 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
795 
796 #define REDUCE_FOLDER(OP) \
797  OpFoldResult OP::fold(FoldAdaptor adaptor) { \
798  ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
799  if (!inputTy.hasRank()) \
800  return {}; \
801  if (inputTy != getType()) \
802  return {}; \
803  if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
804  return getInput(); \
805  return {}; \
806  }
807 
808 REDUCE_FOLDER(ReduceAllOp)
809 REDUCE_FOLDER(ReduceAnyOp)
810 REDUCE_FOLDER(ReduceMaxOp)
811 REDUCE_FOLDER(ReduceMinOp)
812 REDUCE_FOLDER(ReduceProdOp)
813 REDUCE_FOLDER(ReduceSumOp)
814 #undef REDUCE_FOLDER
815 
816 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
817  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
818  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
819 
820  if (!inputTy || !outputTy)
821  return {};
822 
823  // Fold when the input and output types are the same. This is only safe when
824  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
825  // there may still be a productive reshape.
826  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
827  return getInput1();
828 
829  // reshape(reshape(x)) -> reshape(x)
830  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
831  getInput1().getDefiningOp())) {
832  getInput1Mutable().assign(reshapeOp.getInput1());
833  return getResult();
834  }
835 
836  // Cannot create an ElementsAttr from non-int/float/index types
837  if (!inputTy.getElementType().isIntOrIndexOrFloat())
838  return {};
839 
840  // reshape(const(x)) -> const(reshape-attr(x))
841  if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
842  // Constants must have static shape.
843  if (!outputTy.hasStaticShape())
844  return {};
845 
846  // Okay to duplicate splat constants.
847  if (operand.isSplat())
848  return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
849 
850  // Don't duplicate other constants.
851  if (!getInput1().hasOneUse())
852  return {};
853 
854  return operand.reshape(
855  llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
856  }
857 
858  return {};
859 }
860 
861 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
862  // If the pad is all zeros we can fold this operation away.
863  if (adaptor.getPadding() && getInput1().getType() == getType()) {
864  auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
865  if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
866  return getInput1();
867  }
868  }
869 
870  return {};
871 }
872 
873 // Fold away cases where a tosa.resize operation returns a copy
874 // of the input image.
875 OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
876  ArrayRef<int64_t> offset = getOffset();
877  ArrayRef<int64_t> border = getBorder();
878  ArrayRef<int64_t> scale = getScale();
879 
880  // Check unit scaling.
881  if (scale[0] != scale[1] || scale[2] != scale[3]) {
882  return {};
883  }
884 
885  // There should be no offset.
886  if (offset[0] != 0 || offset[1] != 0) {
887  return {};
888  }
889 
890  // There should be no border.
891  if (border[0] != 0 || border[1] != 0) {
892  return {};
893  }
894 
895  auto input = getInput();
896  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
897  auto resultTy = llvm::cast<RankedTensorType>(getType());
898  if (inputTy != resultTy)
899  return {};
900 
901  return input;
902 }
903 
904 OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
905  auto operand = getInput();
906  auto operandTy = llvm::cast<ShapedType>(operand.getType());
907  auto axis = getAxis();
908  auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
909  if (operandAttr)
910  return operandAttr;
911 
912  // If the dim-length is 1, tosa.reverse is a no-op.
913  if (operandTy.hasRank() &&
914  (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
915  return operand;
916 
917  return {};
918 }
919 
920 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
921  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
922  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
923 
924  if (!inputTy || !outputTy)
925  return {};
926 
927  if (inputTy == outputTy && inputTy.hasStaticShape())
928  return getInput();
929 
930  if (!adaptor.getInput())
931  return {};
932 
933  // Cannot create an ElementsAttr from non-int/float/index types
934  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
935  !outputTy.getElementType().isIntOrIndexOrFloat())
936  return {};
937 
938  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
939  if (operand.isSplat() && outputTy.hasStaticShape()) {
940  return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
941  }
942 
943  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
944  outputTy.getNumElements() == 1) {
945  llvm::SmallVector<uint64_t> indices(getStart());
946  auto value = operand.getValues<Attribute>()[indices];
947  return SplatElementsAttr::get(outputTy, value);
948  }
949 
950  return {};
951 }
952 
953 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
954  if (getOnTrue() == getOnFalse())
955  return getOnTrue();
956 
957  auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
958  if (!predicate)
959  return {};
960 
961  if (!predicate.isSplat())
962  return {};
963  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
964  : getOnFalse();
965 }
966 
967 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
968  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
969  if (allOnes && getInput1().getType() == getType())
970  return getInput1();
971  return {};
972 }
973 
974 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
975  auto resultTy = llvm::cast<ShapedType>(getType());
976 
977  // Transposing splat values just means reshaping.
978  if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
979  if (input.isSplat() && resultTy.hasStaticShape() &&
980  input.getType().getElementType() == resultTy.getElementType())
981  return input.reshape(resultTy);
982  }
983 
984  // Transpose does not change the input type.
985  if (getInput1().getType() != getType())
986  return {};
987 
988  // Transpose is not the identity transpose.
989  SmallVector<int64_t> perms;
990  if (getConstantPerms(perms).failed())
991  return {};
992 
993  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
994  return {};
995 
996  return getInput1();
997 }
998 
999 OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
1000  auto input = getInput1();
1001  // Element-wise log(exp(x)) = x
1002  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
1003  return op.getInput1();
1004  }
1005 
1006  return {};
1007 }
1008 
1009 OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
1010  auto input = getInput1();
1011  // Element-wise exp(log(x)) = x
1012  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
1013  return op.getInput1();
1014  }
1015 
1016  return {};
1017 }
1018 
1019 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1020  auto input = getInput1();
1021  // Element-wise negate(negate(x)) = x
1022  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
1023  return op.getInput1();
1024  }
1025 
1026  return {};
1027 }
1028 
1029 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1030  auto input = getInput1();
1031  // Element-wise abs(abs(x)) = abs(x)
1032  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1033  return input;
1034  }
1035 
1036  return {};
1037 }
1038 
1039 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1040  // Fold consecutive concats on the same axis into a single op.
1041  // Keep track of the operands so we are able to construct a new concat
1042  // later. Conservatively assume that we double the number of operands when
1043  // folding
1044  SmallVector<Value, 8> concatOperands;
1045  concatOperands.reserve(2 * getNumOperands());
1046 
1047  // Find all operands that are foldable concats
1048  bool foundFoldableConcat = false;
1049  for (Value operand : getOperands()) {
1050  concatOperands.emplace_back(operand);
1051 
1052  auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
1053  if (!producer)
1054  continue;
1055 
1056  // Not foldable if axes are not the same
1057  if (getAxis() != producer.getAxis())
1058  continue;
1059 
1060  // Replace the original operand with all incoming operands
1061  foundFoldableConcat = true;
1062  concatOperands.pop_back();
1063  llvm::append_range(concatOperands, producer->getOperands());
1064  }
1065 
1066  if (!foundFoldableConcat)
1067  return {};
1068 
1069  getOperation()->setOperands(concatOperands);
1070  return getResult();
1071 }
1072 
1073 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1074  auto input = adaptor.getInput1();
1075 
1076  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1077  // Fold splat inputs only.
1078  if (!inputAttr || !inputAttr.isSplat())
1079  return {};
1080 
1081  auto shapeType = llvm::cast<ShapedType>(getType());
1082  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1083  auto floatVal = inputAttr.getSplatValue<APFloat>();
1084  return DenseElementsAttr::get(shapeType,
1085  ReciprocalOp::calcOneElement(floatVal));
1086  }
1087 
1088  return {};
1089 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy)
#define REDUCE_FOLDER(OP)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:246
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:191
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:269
IntegerType getI32Type()
Definition: Builders.cpp:91
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:136
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:261
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:472
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::PadOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362