1 //===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // TOSA canonicalization patterns and folders.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/APInt.h"
26 
27 #include <functional>
28 
29 using namespace mlir;
30 using namespace mlir::tosa;
31 
32 //===----------------------------------------------------------------------===//
33 // Operator Canonicalizers.
34 //===----------------------------------------------------------------------===//
35 
36 //===----------------------------------------------------------------------===//
37 // Tensor Data Engine Operators.
38 //===----------------------------------------------------------------------===//
39 
40 // Check that the pad constant and the tensor operation's zero point are aligned.
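// Editor's note (illustrative values): for an i8 quantized convolution with
// input_zp = -128 this requires a scalar pad_const of -128, while for a float
// convolution the pad_const must be 0.0.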
41 static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
42  // Check that padConst is a constant value and a scalar tensor
43  DenseElementsAttr padConstAttr;
44  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
45  (padConstAttr.size() != 1)) {
46  return false;
47  }
48 
49  // Check that floating point pad is zero
50  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
51  float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
52  return padConstVal == 0.0f;
53  }
54 
55  // Check that the zp and padConst align for the integer (quantized) case
56  if (auto padConstIntAttr =
57  mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
58  DenseIntElementsAttr zpAttr;
59  // Check that zp is a constant value and a scalar tensor
60  if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
61  return false;
62  }
63 
64  // Check equality
65  int64_t zpVal = (*zpAttr.begin()).getSExtValue();
66  int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
67  return zpVal == padConstVal;
68  }
69 
70  // Bail-out on unsupported type
71  return false;
72 }
73 
74 namespace {
75 template <typename OpTy>
76 struct PoolPadFoldAdaptor;
77 
78 template <>
79 struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
80  using OpTy = tosa::MaxPool2dOp;
81  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
82  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
83  if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
84  newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
85  return false;
86  return true;
87  }
88  static bool checkPadConstCompliance(OpTy, Value padConst) {
89  // Check that padConst is a constant value and a scalar tensor
90  DenseElementsAttr padConstAttr;
91  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
92  padConstAttr.size() != 1) {
93  return false;
94  }
95 
96  // The pad constant must be the minimum representable value to allow the merge
97  if (auto padConstFpAttr =
98  mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
99  const APFloat padConstVal = *padConstFpAttr.begin();
100  const APFloat lowestVal =
101  APFloat::getLargest(padConstVal.getSemantics(), true);
102  return padConstVal == lowestVal;
103  }
104  if (auto padConstIntAttr =
105  mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
106  const APInt padConstVal = *padConstIntAttr.begin();
107  const unsigned int bitWidth = padConstVal.getBitWidth();
108  const APInt lowestVal =
109  padConstIntAttr.getElementType().isUnsignedInteger()
110  ? APInt::getZero(bitWidth)
111  : APInt::getSignedMinValue(bitWidth);
112  return padConstVal == lowestVal;
113  }
114 
115  // Bail-out on unsupported type
116  return false;
117  }
118  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
119  Value padInput, ArrayRef<int64_t> newPad) {
120  rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
121  op, op.getType(), padInput, op.getKernel(), op.getStride(),
122  rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
123  }
124 };
125 
126 template <typename OpTy>
127 struct ConvPadFoldAdaptor {
128  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
129  return true;
130  }
131  static bool checkPadConstCompliance(OpTy op, Value padConst) {
132  return checkMatchingPadConstAndZp(padConst, op.getInputZp());
133  }
134  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
135  Value padInput, ArrayRef<int64_t> newPad) {
136  rewriter.replaceOpWithNewOp<OpTy>(
137  op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
138  op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
139  op.getDilationAttr(), op.getAccType(), op.getLocalBound());
140  }
141 };
142 
143 // Pattern attempts to fold a `tosa.pad` operator into a following tensor
144 // operation like `tosa.conv2d` by merging the padding associated with the
145 // pad operator directly into the implicit padding of the tensor operation.
146 // This helps eliminate the explicit padding operator if it becomes unused.
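//
// Illustrative sketch (editor's example; schematic IR with operand/attribute
// syntax abbreviated and zero points elided):
//
//   %p = tosa.pad %x, padding = [0, 0, 1, 1, 2, 2, 0, 0], %pad_const
//   %y = tosa.conv2d %p, %w, %b {pad = [0, 0, 0, 0], ...}
//
// becomes, when the kernel, pad-constant and 8K-limit checks below hold:
//
//   %y = tosa.conv2d %x, %w, %b {pad = [1, 1, 2, 2], ...}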
147 template <typename OpTy, typename AdaptorTy>
148 struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
149  using OpRewritePattern<OpTy>::OpRewritePattern;
150 
151  LogicalResult matchAndRewrite(OpTy tensorOp,
152  PatternRewriter &rewriter) const override {
153  // Check producer is a tosa::PadOp
154  auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
155  if (!padOp)
156  return rewriter.notifyMatchFailure(tensorOp,
157  "Producer must be a tosa::PadOp.");
158 
159  // Validate that tensor operation has sane padding
160  const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
161  if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
162  return rewriter.notifyMatchFailure(
163  tensorOp, "Tensor operation padding shall have 4 elements.");
164 
165  // Validate tosa::PadOp padding
166  DenseIntElementsAttr padOpPadding;
167  if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
168  return rewriter.notifyMatchFailure(
169  tensorOp,
170  "The `padding` input specified on the tosa::PadOp must be constant.");
171  }
172  // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
173  // C_after
174  if (padOpPadding.size() != 8)
175  return rewriter.notifyMatchFailure(tensorOp,
176  "Pad padding should have 8 elements.");
177  int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
178  int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
179  int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
180  int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
181  int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
182  int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
183  int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
184  int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
185 
186  if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
187  return rewriter.notifyMatchFailure(
188  tensorOp, "Folding padding in N or C dimensions is not supported.");
189 
190  // Fold padding from Pad into the tensor operation
191  // 4 elements - pad_top, pad_bottom, pad_left, pad_right
192  SmallVector<int64_t> foldedPad(tensorOpPad.size());
193  foldedPad[0] = padHBefore + tensorOpPad[0];
194  foldedPad[1] = padHAfter + tensorOpPad[1];
195  foldedPad[2] = padWBefore + tensorOpPad[2];
196  foldedPad[3] = padWAfter + tensorOpPad[3];
197 
198  // Check kernel related restrictions
199  if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
200  return rewriter.notifyMatchFailure(
201  tensorOp, "Padding size not aligned with kernel restrictions.");
202  }
203 
204  // Check padding constant restrictions
205  if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
206  return rewriter.notifyMatchFailure(
207  tensorOp,
208  "Padding constant is not aligned with operator zero-point.");
209  }
210 
211  // Check that the folded padding does not exceed the 8K level limit (8192) for now
212  if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
213  return rewriter.notifyMatchFailure(
214  tensorOp, "Padding size more than the 8K level limit.");
215  }
216 
217  // Create operator
218  AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
219  foldedPad);
220 
221  return success();
222  }
223 };
224 } // namespace
225 
226 void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
227  MLIRContext *context) {
228  results.add<
229  FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
230  context);
231 }
232 
233 void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
234  MLIRContext *context) {
235  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
236  ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
237  context);
238 }
239 
240 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
241  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
242 
243  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
244  PatternRewriter &rewriter) const override {
245  Value input = op.getInput();
246  Value output = op.getOutput();
247  ShapedType inputType = llvm::cast<ShapedType>(input.getType());
248  ShapedType outputType = llvm::cast<ShapedType>(output.getType());
249 
250  if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
251  return failure();
252  }
253 
254  // If both the input and output spatial dimensions (H and W) are 1x1, this is a no-op.
255  ArrayRef<int64_t> outputShape = outputType.getShape();
256  if (outputShape[1] != 1 || outputShape[2] != 1) {
257  return failure();
258  }
259 
260  ArrayRef<int64_t> inputShape = inputType.getShape();
261  if (inputShape[1] != 1 || inputShape[2] != 1) {
262  return failure();
263  }
264 
265  rewriter.replaceOp(op, input);
266  return success();
267  }
268 };
269 
270 void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
271  MLIRContext *context) {
272  results.add<MaxPool2dIsNoOp,
273  FoldPadToTensorOp<tosa::MaxPool2dOp,
274  PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
275  context);
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // Data Layout / Memory Reinterpretation.
280 //===----------------------------------------------------------------------===//
281 
282 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
283  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
284 
285  LogicalResult matchAndRewrite(tosa::ConcatOp op,
286  PatternRewriter &rewriter) const override {
287  if (op.getInput1().size() != 1)
288  return failure();
289  if (op.getInput1().front().getType() != op.getType()) {
290  rewriter
291  .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
292  op.getInput1().front())
293  .getResult();
294  return success();
295  }
296 
297  rewriter.replaceOp(op, op.getInput1().front());
298  return success();
299  }
300 };
301 
302 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
303  MLIRContext *context) {
304  results.add<ConcatOptimization>(context);
305 }
306 
307 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
308  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
309  if (!notOp)
310  return failure();
311  rewriter.modifyOpInPlace(op, [&]() {
312  op.getOperation()->setOperands(
313  {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
314  });
315  return success();
316 }
317 
318 struct ConsolidateTransposeOptimization
319  : public OpRewritePattern<tosa::TransposeOp> {
320  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
321 
322  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
323  PatternRewriter &rewriter) const override {
324  // Input is also TransposeOp - transpose(transpose(A)).
325  auto innerTranspose =
326  transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
327  if (!innerTranspose)
328  return rewriter.notifyMatchFailure(transposeOp,
329  "input must be transpose operation");
330 
331  const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
332  const llvm::ArrayRef<int32_t> innerTransposePerms =
333  innerTranspose.getPerms();
334 
335  if (transposePerms.size() != innerTransposePerms.size())
336  return rewriter.notifyMatchFailure(
337  transposeOp,
338  "transpose and inner transpose perms sizes must be equal");
339  if (transposePerms.empty())
340  return rewriter.notifyMatchFailure(
341  transposeOp, "transpose perms sizes must be positive");
342 
343  // Consolidate transposes into one transpose.
344  SmallVector<int32_t> perms(transposePerms.size());
345  for (int i = 0, s = transposePerms.size(); i < s; ++i)
346  perms[i] = innerTransposePerms[transposePerms[i]];
347 
348  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
349  transposeOp, transposeOp.getResult().getType(),
350  innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
351 
352  return success();
353  }
354 };
355 
356 // Detects the case where a tosa.transpose is equivalent to a tosa.reshape operation.
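// Editor's illustration (schematic IR): when the non-unit dimensions keep
// their relative order, the transpose only moves unit dimensions, e.g.
//   %t = tosa.transpose %x {perms = array<i32: 1, 0, 2>}
//       : (tensor<1x8x16xf32>) -> tensor<8x1x16xf32>
// can be rewritten as a tosa.reshape of %x to tensor<8x1x16xf32>.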
357 struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
358  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
359 
360  LogicalResult matchAndRewrite(tosa::TransposeOp op,
361  PatternRewriter &rewriter) const override {
362  if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
363  return rewriter.notifyMatchFailure(
364  op, "Src is from transpose, can compose transposes");
365 
366  Value result = op.getResult();
367  for (Operation *subop : result.getUsers()) {
368  if (isa_and_nonnull<tosa::TransposeOp>(subop))
369  return rewriter.notifyMatchFailure(
370  op, "Dest is used by transpose, can compose transposes");
371  }
372 
373  auto input = op.getInput1();
374  auto inputTy = llvm::cast<ShapedType>(input.getType());
375  if (!inputTy.hasRank())
376  return rewriter.notifyMatchFailure(op, "Unranked input.");
377 
378  int64_t numDynDims = 0;
379  for (int i = 0; i < inputTy.getRank(); ++i)
380  if (inputTy.isDynamicDim(i))
381  numDynDims++;
382 
383  if (numDynDims > 1)
384  return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
385 
386  const llvm::ArrayRef<int32_t> permValues = op.getPerms();
387 
388  SmallVector<int64_t> nonZeroPerms;
389  nonZeroPerms.reserve(permValues.size());
390  for (auto idx : permValues) {
391  auto sz = inputTy.getDimSize(idx);
392  if (sz != 1)
393  nonZeroPerms.push_back(idx);
394  }
395 
396  for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
397  if (nonZeroPerms[i - 1] > nonZeroPerms[i])
398  return rewriter.notifyMatchFailure(op,
399  "Transpose changes memory layout.");
400 
401  SmallVector<int64_t> newShape;
402  newShape.reserve(inputTy.getRank());
403  for (int i = 0, s = inputTy.getRank(); i < s; ++i)
404  newShape.push_back(inputTy.getDimSize(permValues[i]));
405 
406  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
407  op, op.getType(), op.getInput1(),
408  getTosaConstShape(rewriter, op.getLoc(), newShape));
409  return success();
410  }
411 };
412 
413 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
414  MLIRContext *context) {
415  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
416 }
417 
418 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
419  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
420 
421  LogicalResult matchAndRewrite(tosa::ClampOp op,
422  PatternRewriter &rewriter) const override {
423  Value input = op.getInput();
424  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
425  auto inputElementType = inputType.getElementType();
426 
427  if (isa<FloatType>(inputElementType)) {
428  // Unlike integer types, floating point types can represent infinity.
429  const auto minClamp =
430  llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
431  const auto maxClamp =
432  llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
433  const bool isMin = minClamp.isNegInfinity();
434  const bool isMax = maxClamp.isInfinity();
435 
436  if (isMin && isMax) {
437  rewriter.replaceOp(op, input);
438  return success();
439  }
440  return failure();
441  }
442 
443  // i1 types are boolean in TOSA
444  const bool isBoolean = inputElementType.isInteger(1);
445  if (inputElementType.isUnsignedInteger() || isBoolean) {
446  const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
447  .getValue()
448  .getZExtValue();
449  const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
450  .getValue()
451  .getZExtValue();
452 
453  const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
454  const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
455  const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
456 
457  if (minClamp <= intMin && maxClamp >= intMax) {
458  rewriter.replaceOp(op, input);
459  return success();
460  }
461  return failure();
462  }
463 
464  if (llvm::isa<IntegerType>(inputElementType)) {
465  const int64_t minClamp =
466  llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
467  const int64_t maxClamp =
468  llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
469 
470  const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
471  const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
472  const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
473 
474  if (minClamp <= intMin && maxClamp >= intMax) {
475  rewriter.replaceOp(op, input);
476  return success();
477  }
478  return failure();
479  }
480 
481  return failure();
482  }
483 };
484 
485 // Attempts the following transformation:
486 //
487 // For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
488 // tensor X the following identity holds:
489 //
490 // CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
491 //
492 // subject to the following valid NaN propagation semantics:
493 // ---------------------------------------------
494 // | OUTER CLAMP | INNER CLAMP | RESULT MODE |
495 // |-------------|-------------|-------------|
496 // | PROPAGATE   | PROPAGATE   | PROPAGATE   |
497 // | PROPAGATE   | IGNORE      | IGNORE      |
498 // | IGNORE      | PROPAGATE   | INVALID     |
499 // | IGNORE      | IGNORE      | IGNORE      |
500 // ---------------------------------------------
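//
// Worked example (editor's illustration): CLAMP(CLAMP(X, 0, 10), 5, 20) has
// intersecting ranges [0, 10] and [5, 20], so it folds to CLAMP(X, 5, 10);
// disjoint ranges such as [0, 10] and [20, 30] are left untouched.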
501 
502 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
503  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
504 
505  // Helper structure to describe the range of a clamp operation.
506  template <typename T>
507  struct ClampRange {
508  ClampRange(const T &start, const T &end) : start(start), end(end) {}
509  T start;
510  T end;
511 
512  // Helper function to determine if two Clamp ranges intersect.
513  bool intersects(const ClampRange<T> &otherRange) {
514  return start < otherRange.end && otherRange.start < end;
515  }
516  };
517 
518  LogicalResult matchAndRewrite(tosa::ClampOp op,
519  PatternRewriter &rewriter) const override {
520  Value input = op.getInput();
521 
522  // Check the input to the CLAMP op is itself a CLAMP.
523  auto clampOp = input.getDefiningOp<tosa::ClampOp>();
524  if (!clampOp)
525  return failure();
526 
527  // Check we have a valid NaN propagation combination.
528  const auto opNanMode = op.getNanMode();
529  const auto clampNanMode = clampOp.getNanMode();
530  if (opNanMode == NanPropagationMode::IGNORE &&
531  clampNanMode == NanPropagationMode::PROPAGATE)
532  return failure();
533 
534  auto maxValAttr = op.getMaxValAttr();
535  auto minValAttr = op.getMinValAttr();
536  auto clampOpMaxValAttr = clampOp.getMaxValAttr();
537  auto clampOpMinValAttr = clampOp.getMinValAttr();
538 
539  auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
540  if (auto quantType =
541  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
542  inputEType = quantType.getStorageType();
543  }
544 
545  Attribute newMinValAttr, newMaxValAttr;
546  if (mlir::isa<FloatType>(inputEType)) {
547  auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
548  auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
549  auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
550  auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
551 
552  // Check we have intersecting ranges.
553  const auto opMinFloat = floatMinValAttr.getValue();
554  const auto opMaxFloat = floatMaxValAttr.getValue();
555  const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
556  const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
557  ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
558  ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
559  clampOpMaxFloat);
560  if (!opRangeFloatRange.intersects(clampRangeFloatRange))
561  return failure();
562 
563  // Run the transformation.
564  auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
565  auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
566  newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
567  newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
568  } else {
569  assert(mlir::isa<IntegerType>(inputEType));
570  auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
571  auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
572  auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
573  auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
574 
575  if (inputEType.isUnsignedInteger()) {
576  // Check we have intersecting ranges.
577  const auto opMinInt = intMinValAttr.getUInt();
578  const auto opMaxInt = intMaxValAttr.getUInt();
579  const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
580  const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
581  ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
582  ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
583  clampOpMaxInt);
584  if (!opRangeIntRange.intersects(clampRangeIntRange))
585  return failure();
586 
587  // Run the transformation.
588  auto newMinVal = std::max(opMinInt, clampOpMinInt);
589  auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
590  newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
591  newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
592  } else {
593  // Check we have intersecting ranges.
594  const auto opMinInt = intMinValAttr.getInt();
595  const auto opMaxInt = intMaxValAttr.getInt();
596  const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
597  const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
598  ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
599  ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
600  clampOpMaxInt);
601  if (!opRangeIntRange.intersects(clampRangeIntRange))
602  return failure();
603 
604  // Run the transformation.
605  auto newMinVal = std::max(opMinInt, clampOpMinInt);
606  auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
607  newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
608  newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
609  }
610  }
611 
612  auto newMode = (opNanMode != clampNanMode)
613  ? tosa::NanPropagationMode::IGNORE
614  : opNanMode;
615 
616  auto newModeAttr =
617  NanPropagationModeAttr::get(rewriter.getContext(), newMode);
618 
619  rewriter.replaceOpWithNewOp<tosa::ClampOp>(
620  op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
621  newModeAttr);
622  return success();
623  }
624 };
625 
626 void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
627  MLIRContext *context) {
628  results.add<ClampIsNoOp>(context);
629  results.add<ClampClampOptimization>(context);
630 }
631 
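// Replaces a tosa.slice of a tosa.concat with a slice of the single concat
// input that fully contains the sliced region (editor's summary).
// Schematically:
//   %c = tosa.concat %a, %b {axis = 0}
//       : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<8x8xf32>
//   %s = tosa.slice %c, start = [5, 0], size = [2, 8]
// reads only from %b, so it becomes a slice of %b with start = [1, 0].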
632 struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
633  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
634 
635  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
636  PatternRewriter &rewriter) const override {
637  Value sliceInput = sliceOp.getInput1();
638  auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
639  if (!concatOp)
640  return rewriter.notifyMatchFailure(
641  sliceOp, "slice input must be concat operation");
642 
643  OperandRange inputs = concatOp.getInput1();
644  auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
645  if (!concatType || !concatType.hasStaticShape())
646  return rewriter.notifyMatchFailure(
647  sliceOp, "slice input must be a static ranked tensor");
648  int32_t axis = concatOp.getAxis();
649 
650  DenseElementsAttr startElems;
651  DenseElementsAttr sizeElems;
652 
653  if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
654  return rewriter.notifyMatchFailure(
655  sliceOp, "start of slice must be a static ranked shape");
656 
657  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
658  return rewriter.notifyMatchFailure(
659  sliceOp, "size of slice must be a static ranked shape");
660 
661  llvm::SmallVector<int64_t> sliceStarts =
662  llvm::to_vector(startElems.getValues<int64_t>());
663  llvm::SmallVector<int64_t> sliceSizes =
664  llvm::to_vector(sizeElems.getValues<int64_t>());
665 
666  // Validate slice on the concatenated axis. Slicing along this
667  // axis should span only one of the inputs to the concatenate
668  // operation.
669  std::optional<Value> replaceWithSlice;
670  for (auto input : inputs) {
671  auto inputType = dyn_cast<RankedTensorType>(input.getType());
672  if (!inputType || !inputType.hasStaticShape())
673  return rewriter.notifyMatchFailure(
674  sliceOp, "concat input must be a static ranked tensor");
675 
676  if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
677  inputType.getDimSize(axis)) {
678  auto start_op =
679  getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
680  auto size_op =
681  getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
682  replaceWithSlice =
683  tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
684  input, start_op, size_op)
685  .getResult();
686  break;
687  }
688  sliceStarts[axis] -= inputType.getDimSize(axis);
689  }
690 
691  if (!replaceWithSlice)
692  return rewriter.notifyMatchFailure(
693  sliceOp, "corresponding concat input not found for slice");
694 
695  rewriter.replaceOp(sliceOp, replaceWithSlice.value());
696  return success();
697  }
698 };
699 
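// Pushes a tosa.slice up through a tosa.pad by shrinking the padding and
// adjusting the slice start so that padded regions that are sliced away are
// never materialized (editor's summary). Schematically, slicing rows [0, 4)
// out of pad(%x : tensor<4x8xf32>, top = 2, bottom = 2) needs only the top
// padding and the first two rows of %x, so the pad becomes top = 2,
// bottom = 0 and the slice start becomes 0.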
700 struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
701  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
702 
703  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
704  PatternRewriter &rewriter) const override {
705  Value sliceInput = sliceOp.getInput1();
706 
707  // Check if producer is a PadOp
708  auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
709  if (!padOp)
710  return rewriter.notifyMatchFailure(sliceOp,
711  "slice input must be a pad operation");
712 
713  // Check PadOp has a single consumer
714  if (!padOp->hasOneUse())
715  return rewriter.notifyMatchFailure(sliceOp,
716  "pad shall have a single consumer");
717 
718  // Check input is statically ranked
719  auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
720  auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
721  if (!inputTy || !padTy || !inputTy.hasRank())
722  return rewriter.notifyMatchFailure(sliceOp,
723  "slice input must be a ranked tensor");
724 
725  // Validate and extract tosa::PadOp padding
726  DenseIntElementsAttr paddingElems;
727  if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
728  return rewriter.notifyMatchFailure(
729  sliceOp,
730  "`padding` input specified on the tosa::PadOp must be constant.");
731  }
732  llvm::SmallVector<int64_t> padPaddings =
733  llvm::to_vector(paddingElems.getValues<int64_t>());
734 
735  // Extract slice parameters
736  DenseElementsAttr startElems;
737  if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
738  return rewriter.notifyMatchFailure(
739  sliceOp, "start of slice must be a static ranked shape");
740  llvm::SmallVector<int64_t> sliceStarts =
741  llvm::to_vector(startElems.getValues<int64_t>());
742 
743  DenseElementsAttr sizeElems;
744  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
745  return rewriter.notifyMatchFailure(
746  sliceOp, "size of slice must be a static ranked shape");
747  llvm::SmallVector<int64_t> sliceSizes =
748  llvm::to_vector(sizeElems.getValues<int64_t>());
749 
750  // Check if dynamic dimensions are sliced
751  const int64_t rank = inputTy.getRank();
752  if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
753  const bool isDimDynamic = inputTy.isDynamicDim(i);
754  const bool isDimSliced =
755  (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
756 
757  return isDimDynamic && isDimSliced;
758  })) {
759  return rewriter.notifyMatchFailure(
760  sliceOp, "axis that are sliced shall be statically known.");
761  }
762 
763  // Update the parameters
764  llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
765  llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
766  llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
767  bool updated = false;
768 
769  for (int64_t i = 0; i < rank; ++i) {
770  const int64_t padLo = padPaddings[i * 2];
771  const int64_t padHi = padPaddings[i * 2 + 1];
772  const int64_t sliceStart = sliceStarts[i];
773  const int64_t sliceSize = sliceSizes[i];
774  const int64_t sliceEnd = sliceStart + sliceSize;
775 
776  // If the dimension is dynamic, pass it through unchanged
777  if (inputTy.isDynamicDim(i)) {
778  newPadPaddings[i * 2] = padLo;
779  newPadPaddings[i * 2 + 1] = padHi;
780  newSliceStarts[i] = sliceStart;
781  continue;
782  }
783 
784  // Handle static dimensions
785  const int64_t dimSize = inputTy.getShape()[i];
786  const int64_t dimTotal = padLo + dimSize + padHi;
787 
788  // Check slice within bounds
789  if (sliceStart < 0 || sliceEnd > dimTotal)
790  return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
791 
792  // Compute updated slice start parameter
793  const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
794  newSliceStarts[i] = newSliceStart;
795  updated |= newSliceStart != sliceStart;
796 
797  // Compute updated pad parameters
798  const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
799  const int64_t newPadHi =
800  std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
801  newPadPaddings[i * 2] = newPadLo;
802  newPadPaddings[i * 2 + 1] = newPadHi;
803  updated |= (newPadLo != padLo) || (newPadHi != padHi);
804 
805  // Calculate new pad output shape
806  newPadShape[i] =
807  newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
808  }
809 
810  // Check that we actually need to proceed with the rewrite
811  if (!updated)
812  return rewriter.notifyMatchFailure(
813  sliceOp, "terminate condition; nothing to rewrite");
814 
815  // Create a PadOp with updated padding
816  auto newPaddingsOp =
817  getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
818  auto newPadTy =
819  RankedTensorType::get(newPadShape, inputTy.getElementType());
820  auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
821  padOp.getInput1(), newPaddingsOp,
822  padOp.getPadConst());
823 
824  // Update SliceOp and point to new PadOp
825  auto newStartOp =
826  getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
827  rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
828  newPadOp.getResult(), newStartOp,
829  sliceOp.getSize());
830 
831  return success();
832  }
833 };
834 
835 // Update size operand of tosa.slice if size has dynamic dims but corresponding
836 // output dim is static
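// Editor's illustration (schematic): if the size operand is [-1, 4] but the
// result type is tensor<16x4xf32>, the -1 can be replaced by 16 because the
// static result type already pins that extent.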
837 struct SliceDynamicSizeCanonicalization
838  : public OpRewritePattern<tosa::SliceOp> {
839  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
840 
841  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
842  PatternRewriter &rewriter) const override {
843  ShapedType resultType = cast<ShapedType>(sliceOp.getType());
844 
845  ElementsAttr sizeElems;
846  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
847  return rewriter.notifyMatchFailure(
848  sliceOp, "size of slice must be a static ranked shape");
849  }
850 
851  llvm::SmallVector<int64_t> sliceSizes =
852  llvm::to_vector(sizeElems.getValues<int64_t>());
853 
854  bool replaceSliceSize{false};
855  // If the size operand has -1 (dynamic) for a dimension whose corresponding
856  // output dimension is statically known, update the size to match the known
857  // output shape.
858  for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
859  if (size == -1 && !resultType.isDynamicDim(index)) {
860  sliceSizes[index] = resultType.getDimSize(index);
861  replaceSliceSize = true;
862  }
863  }
864 
865  if (!replaceSliceSize) {
866  return rewriter.notifyMatchFailure(
867  sliceOp, "no dimension of size of slice is dynamic that resolves "
868  "to static output shape");
869  }
870 
871  auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
872  auto newSliceOp =
873  tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
874  sliceOp.getInput1(), sliceOp.getStart(), size_op);
875 
876  rewriter.replaceOp(sliceOp, newSliceOp.getResult());
877  return success();
878  }
879 };
880 
881 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
882  MLIRContext *context) {
883  results.add<ConcatSliceOptimization, PadSliceOptimization,
884  SliceDynamicSizeCanonicalization>(context);
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // Operator Folders.
889 //===----------------------------------------------------------------------===//
890 
891 template <typename IntFolder, typename FloatFolder>
892 static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
893  DenseElementsAttr rhs,
894  RankedTensorType returnTy) {
895  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
896  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
897  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
898  if (lETy != rETy)
899  return {};
900 
901  if (llvm::isa<IntegerType>(lETy)) {
902  APInt l = lhs.getSplatValue<APInt>();
903  APInt r = rhs.getSplatValue<APInt>();
904  auto result = IntFolder()(l, r);
905  return DenseElementsAttr::get(returnTy, result);
906  }
907 
908  if (llvm::isa<FloatType>(lETy)) {
909  APFloat l = lhs.getSplatValue<APFloat>();
910  APFloat r = rhs.getSplatValue<APFloat>();
911  auto result = FloatFolder()(l, r);
912  return DenseElementsAttr::get(returnTy, result);
913  }
914  }
915 
916  return {};
917 }
918 
919 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
920  if (llvm::isa<FloatType>(elemType))
921  return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
922  if (llvm::isa<IntegerType>(elemType))
923  return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
924  return false;
925 }
926 
927 static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
928  if (llvm::isa<FloatType>(elemType))
929  return val && val.isSplat() &&
930  val.getSplatValue<APFloat>().isExactlyValue(1.0);
931  if (llvm::isa<IntegerType>(elemType)) {
932  const int64_t shifted = 1LL << shift;
933  return val && val.isSplat() &&
934  val.getSplatValue<APInt>().getSExtValue() == shifted;
935  }
936  return false;
937 }
938 
939 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
940  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
941  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
942  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
943  if (!lhsTy || !rhsTy || !resultTy)
944  return {};
945 
946  // Cannot create an ElementsAttr from non-int/float/index types
947  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
948  !rhsTy.getElementType().isIntOrIndexOrFloat())
949  return {};
950 
951  auto resultETy = resultTy.getElementType();
952  auto lhsAttr =
953  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
954  auto rhsAttr =
955  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
956 
957  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
958  return getInput1();
959  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
960  return getInput2();
961 
962  if (!lhsAttr || !rhsAttr)
963  return {};
964 
965  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
966  resultTy);
967 }
968 
969 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
970  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
971  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
972  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
973  !outputTy.hasStaticShape())
974  return {};
975 
976  const Type outputElementTy = getElementTypeOrSelf(outputTy);
977  if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
978  const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
979  const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
980  return DenseElementsAttr::get(outputTy, zero);
981  }
982 
983  return {};
984 }
985 
986 OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
987  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
988  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
989  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
990  if (!lhsTy || !rhsTy || !resultTy)
991  return {};
992  if (lhsTy != rhsTy)
993  return {};
994 
995  // IntDivOp inputs must be integer type, no need to check for quantized type
996  auto resultETy = resultTy.getElementType();
997  auto lhsAttr =
998  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
999  auto rhsAttr =
1000  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1001  if (lhsAttr && lhsAttr.isSplat()) {
1002  if (llvm::isa<IntegerType>(resultETy) &&
1003  lhsAttr.getSplatValue<APInt>().isZero())
1004  return lhsAttr;
1005  }
1006 
1007  if (rhsAttr && rhsAttr.isSplat()) {
1008  if (llvm::isa<IntegerType>(resultETy) &&
1009  rhsAttr.getSplatValue<APInt>().isOne())
1010  return getInput1();
1011  }
1012 
1013  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1014  llvm::isa<IntegerType>(resultETy)) {
1015  APInt l = lhsAttr.getSplatValue<APInt>();
1016  APInt r = rhsAttr.getSplatValue<APInt>();
1017  if (!r.isZero()) {
1018  APInt result = l.sdiv(r);
1019  return DenseElementsAttr::get(resultTy, result);
1020  }
1021  }
1022 
1023  return {};
1024 }
1025 
1026 namespace {
1027 // Calculate (lhs * rhs) >> shift according to the TOSA specification.
1028 // Returns std::nullopt if the result is out of int32_t range when shift > 0.
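// Worked example (editor's illustration): with lhs = 5, rhs = 7 and shift = 2
// the 64-bit product is 35; adding the rounding term 1 << (shift - 1) = 2
// gives 37, and the arithmetic shift right by 2 yields 9, which is then
// truncated back to the operand bitwidth.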
1029 std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1030  unsigned bitwidth) {
1031  APInt result = lhs.sext(64) * rhs.sext(64);
1032 
1033  if (shift > 0) {
1034  auto round = APInt(64, 1) << (shift - 1);
1035  result += round;
1036  result.ashrInPlace(shift);
1037  // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
1038  if (!(result.getSExtValue() >= INT32_MIN &&
1039  result.getSExtValue() <= INT32_MAX)) {
1040  // REQUIRE failed
1041  return std::nullopt;
1042  }
1043  }
1044 
1045  return result.trunc(bitwidth);
1046 }
1047 
1048 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1049  RankedTensorType ty, int32_t shift) {
1050  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1051  if (llvm::isa<IntegerType>(ty.getElementType())) {
1052  APInt l = lhs.getSplatValue<APInt>();
1053  APInt r = rhs.getSplatValue<APInt>();
1054 
1055  if (shift == 0) {
1056  return DenseElementsAttr::get(ty, l * r);
1057  }
1058 
1059  auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1060  const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1061  if (!result)
1062  return {};
1063  return DenseElementsAttr::get(ty, result.value());
1064  }
1065 
1066  if (llvm::isa<FloatType>(ty.getElementType())) {
1067  APFloat l = lhs.getSplatValue<APFloat>();
1068  APFloat r = rhs.getSplatValue<APFloat>();
1069  APFloat result = l * r;
1070  return DenseElementsAttr::get(ty, result);
1071  }
1072  }
1073 
1074  return {};
1075 }
1076 } // namespace
1077 
1078 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1079  auto lhs = getInput1();
1080  auto rhs = getInput2();
1081  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1082  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1083  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1084  if (!lhsTy || !rhsTy || !resultTy)
1085  return {};
1086 
1087  auto resultETy = resultTy.getElementType();
1088  auto lhsAttr =
1089  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1090  auto rhsAttr =
1091  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1092 
1093  // Result right shift on i32_t data type only. For simplification, synthesize
1094  // a zero shift for other data types.
1095  int32_t shift = 0;
1096  if (resultETy.isInteger(32)) {
1097  ElementsAttr shift_elem;
1098  if (getShift().getImpl()) {
1099  if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1100  // cannot be folded when the shift value is unknown.
1101  return {};
1102  shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1103  }
1104  }
1105 
1106  if (rhsTy == resultTy) {
1107  if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1108  // constant values can only be resized if resulting type is static
1109  return lhsAttr.resizeSplat(resultTy);
1110  if (isSplatOne(resultETy, lhsAttr, shift))
1111  return rhs;
1112  }
1113  if (lhsTy == resultTy) {
1114  if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1115  return rhsAttr.resizeSplat(resultTy);
1116  if (isSplatOne(resultETy, rhsAttr, shift))
1117  return lhs;
1118  }
1119 
1120  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1121 }
1122 
1123 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1124  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1125  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1126  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1127  if (!lhsTy || !rhsTy || !resultTy)
1128  return {};
1129 
1130  // Cannot create an ElementsAttr from non-int/float/index types
1131  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1132  !rhsTy.getElementType().isIntOrIndexOrFloat())
1133  return {};
1134 
1135  auto resultETy = resultTy.getElementType();
1136  auto lhsAttr =
1137  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1138  auto rhsAttr =
1139  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1140 
1141  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1142  return getInput1();
1143 
1144  if (!lhsAttr || !rhsAttr)
1145  return {};
1146 
1147  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
1148  resultTy);
1149 }
1150 
1151 namespace {
1152 template <typename Cmp>
1153 struct ComparisonFold {
1154  ComparisonFold() = default;
1155  APInt operator()(const APInt &l, const APInt &r) {
1156  return APInt(1, Cmp()(l, r));
1157  }
1158 
1159  APInt operator()(const APFloat &l, const APFloat &r) {
1160  return APInt(1, Cmp()(l, r));
1161  }
1162 };
1163 
1164 struct APIntFoldGreater {
1165  APIntFoldGreater() = default;
1166  APInt operator()(const APInt &l, const APInt &r) {
1167  return APInt(1, l.sgt(r));
1168  }
1169 };
1170 
1171 struct APIntFoldGreaterEqual {
1172  APIntFoldGreaterEqual() = default;
1173  APInt operator()(const APInt &l, const APInt &r) {
1174  return APInt(1, l.sge(r));
1175  }
1176 };
1177 } // namespace
1178 
1179 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1180  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1181  auto lhsAttr =
1182  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1183  auto rhsAttr =
1184  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1185 
1186  if (!lhsAttr || !rhsAttr)
1187  return {};
1188 
1189  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1190  lhsAttr, rhsAttr, resultTy);
1191 }
1192 
1193 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1194  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1195  auto lhsAttr =
1196  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1197  auto rhsAttr =
1198  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1199 
1200  if (!lhsAttr || !rhsAttr)
1201  return {};
1202 
1203  return binaryFolder<APIntFoldGreaterEqual,
1204  ComparisonFold<std::greater_equal<APFloat>>>(
1205  lhsAttr, rhsAttr, resultTy);
1206 }
1207 
1208 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1209  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1210  auto lhsAttr =
1211  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1212  auto rhsAttr =
1213  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1214  Value lhs = getInput1();
1215  Value rhs = getInput2();
1216  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1217 
1218  // If we are comparing an integer value to itself it is always true. We
1219  // cannot do this for floats because NaN does not compare equal to itself.
1220  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1221  resultTy.hasStaticShape() && lhs == rhs) {
1222  return DenseElementsAttr::get(resultTy, true);
1223  }
1224 
1225  if (!lhsAttr || !rhsAttr)
1226  return {};
1227 
1228  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1229  ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1230  resultTy);
1231 }
1232 
1233 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1234  if (getInput().getType() == getType())
1235  return getInput();
1236 
1237  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1238  if (!operand)
1239  return {};
1240 
1241  auto inTy = llvm::cast<ShapedType>(getInput().getType());
1242  auto outTy = llvm::cast<ShapedType>(getType());
1243  auto inETy = inTy.getElementType();
1244  auto outETy = outTy.getElementType();
1245 
1246  if (operand.isSplat()) {
1247  if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1248  bool overflow;
1249  auto splatVal = operand.getSplatValue<APFloat>();
1250  auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1251  splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1252  &overflow);
1253  return SplatElementsAttr::get(outTy, splatVal);
1254  }
1255 
1256  if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1257  auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1258  APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1259  splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1260  llvm::RoundingMode::NearestTiesToEven);
1261  return SplatElementsAttr::get(outTy, splatVal);
1262  }
1263 
1264  if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1265  auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1266  auto intVal = APSInt(
1267  llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1268  auto floatVal = operand.getSplatValue<APFloat>();
1269  bool exact;
1270  floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1271  &exact);
1272  return SplatElementsAttr::get(outTy, intVal);
1273  }
1274 
1275  if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1276  const auto inIntType = llvm::cast<IntegerType>(inETy);
1277  auto unsignIn = inIntType.isUnsignedInteger();
1278  bool trunc =
1279  inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1280  auto intVal = operand.getSplatValue<APInt>();
1281  auto bitwidth = outETy.getIntOrFloatBitWidth();
1282 
1283  // i1 types are boolean in TOSA
1284  if (outETy.isInteger(1)) {
1285  intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1286  } else if (trunc) {
1287  intVal = intVal.trunc(bitwidth);
1288  } else if (unsignIn || inIntType.isInteger(1)) {
1289  intVal = intVal.zext(bitwidth);
1290  } else {
1291  intVal = intVal.sext(bitwidth);
1292  }
1293 
1294  return SplatElementsAttr::get(outTy, intVal);
1295  }
1296  }
1297 
1298  return {};
1299 }
1300 
1301 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1302 
1303 OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1304 
1305 #define REDUCE_FOLDER(OP) \
1306  OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1307  ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1308  if (!inputTy.hasRank()) \
1309  return {}; \
1310  if (inputTy != getType()) \
1311  return {}; \
1312  if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1313  return getInput(); \
1314  return {}; \
1315  }
1316 
1317 REDUCE_FOLDER(ReduceAllOp)
1318 REDUCE_FOLDER(ReduceAnyOp)
1319 REDUCE_FOLDER(ReduceMaxOp)
1320 REDUCE_FOLDER(ReduceMinOp)
1321 REDUCE_FOLDER(ReduceProductOp)
1322 REDUCE_FOLDER(ReduceSumOp)
1323 #undef REDUCE_FOLDER
1324 
1325 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1326  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1327  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1328 
1329  if (!inputTy || !outputTy)
1330  return {};
1331 
1332  // Fold when the input and output types are the same. This is only safe when
1333  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
1334  // there may still be a productive reshape.
1335  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1336  return getInput1();
1337 
1338  // reshape(reshape(x)) -> reshape(x)
1339  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1340  getInput1().getDefiningOp())) {
1341  getInput1Mutable().assign(reshapeOp.getInput1());
1342  return getResult();
1343  }
1344 
1345  // Cannot create an ElementsAttr from non-int/float/index types
1346  if (!inputTy.getElementType().isIntOrIndexOrFloat())
1347  return {};
1348 
1349  // reshape(const(x)) -> const(reshape-attr(x))
1350  if (auto operand =
1351  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1352  // Constants must have static shape.
1353  if (!outputTy.hasStaticShape())
1354  return {};
1355 
1356  // Okay to duplicate splat constants.
1357  if (operand.isSplat())
1358  return SplatElementsAttr::get(outputTy,
1359  operand.getSplatValue<Attribute>());
1360 
1361  // Don't duplicate other constants.
1362  if (!getInput1().hasOneUse())
1363  return {};
1364 
1365  llvm::SmallVector<int64_t> shapeVec;
1366  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1367  return {};
1368 
1369  return operand.reshape(
1370  llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1371  }
1372 
1373  return {};
1374 }
1375 
1376 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1377  // If the pad is all zeros we can fold this operation away.
1378  if (adaptor.getPadding() && getInput1().getType() == getType()) {
1379  auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1380  if (densePad && densePad.isSplat() &&
1381  densePad.getSplatValue<APInt>().isZero()) {
1382  return getInput1();
1383  }
1384  }
1385 
1386  return {};
1387 }
1388 
1389 // Fold away cases where a tosa.resize operation returns a copy
1390 // of the input image.
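// Editor's note (illustrative values): this covers resizes whose scale
// numerators equal their denominators, e.g. scale = [2, 2, 3, 3] with
// offset = [0, 0] and border = [0, 0], provided the input and result types
// match exactly.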
1391 OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1392  auto scaleAttr =
1393  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1394  auto offsetAttr =
1395  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1396  auto borderAttr =
1397  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1398  if (!scaleAttr || !offsetAttr || !borderAttr) {
1399  return {};
1400  }
1401 
1402  auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1403  auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1404  auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1405  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1406  return {};
1407  }
1408 
1409  // Check unit scaling.
1410  if (scale[0] != scale[1] || scale[2] != scale[3]) {
1411  return {};
1412  }
1413 
1414  // There should be no offset.
1415  if (offset[0] != 0 || offset[1] != 0) {
1416  return {};
1417  }
1418 
1419  // There should be no border.
1420  if (border[0] != 0 || border[1] != 0) {
1421  return {};
1422  }
1423 
1424  auto input = getInput();
1425  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1426  auto resultTy = llvm::cast<RankedTensorType>(getType());
1427  if (inputTy != resultTy)
1428  return {};
1429 
1430  return input;
1431 }
1432 
1433 OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1434  auto operand = getInput1();
1435  auto operandTy = llvm::cast<ShapedType>(operand.getType());
1436  auto axis = getAxis();
1437  auto operandAttr =
1438  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1439  if (operandAttr)
1440  return operandAttr;
1441 
1442  // If the dim-length is 1, tosa.reverse is a no-op.
1443  if (operandTy.hasRank() &&
1444  (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1445  return operand;
1446 
1447  return {};
1448 }
1449 
1450 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1451  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1452  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1453 
1454  if (!inputTy || !outputTy)
1455  return {};
1456 
1457  if (inputTy == outputTy && inputTy.hasStaticShape())
1458  return getInput1();
1459 
1460  if (!adaptor.getInput1())
1461  return {};
1462 
1463  // Cannot create an ElementsAttr from non-int/float/index types
1464  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1465  !outputTy.getElementType().isIntOrIndexOrFloat())
1466  return {};
1467 
1468  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1469  if (operand.isSplat() && outputTy.hasStaticShape()) {
1470  return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
1471  }
1472 
1473  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1474  outputTy.getNumElements() == 1) {
1475  DenseElementsAttr startElems;
1476  if (!matchPattern(getStart(), m_Constant(&startElems)))
1477  return {};
1478 
1479  llvm::SmallVector<uint64_t> indices =
1480  llvm::to_vector(startElems.getValues<uint64_t>());
1481  auto value = operand.getValues<Attribute>()[indices];
1482  return SplatElementsAttr::get(outputTy, value);
1483  }
1484 
1485  return {};
1486 }
1487 
1488 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1489  if (getOnTrue() == getOnFalse())
1490  return getOnTrue();
1491 
1492  auto predicate =
1493  llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1494  if (!predicate)
1495  return {};
1496 
1497  if (!predicate.isSplat())
1498  return {};
1499  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1500  : getOnFalse();
1501 }
1502 
1503 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1504  if (getInput1().getType() == getType()) {
1505  if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1506  adaptor.getMultiples())) {
1507  if (multiples.isSplat() &&
1508  multiples.getSplatValue<APInt>().getSExtValue() == 1)
1509  return getInput1();
1510  if (auto int_array_attr =
1511  llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1512  if (llvm::all_of(int_array_attr.getValues<APInt>(),
1513  [](APInt v) { return v.getSExtValue() == 1; }))
1514  return getInput1();
1515  }
1516  }
1517  }
1518  return {};
1519 }
1520 
1521 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1522  auto resultTy = llvm::cast<ShapedType>(getType());
1523 
1524  // Transposing splat values just means reshaping.
1525  if (auto input =
1526  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1527  if (input.isSplat() && resultTy.hasStaticShape() &&
1528  input.getType().getElementType() == resultTy.getElementType())
1529  return input.reshape(resultTy);
1530  }
1531 
1532  // Bail out if the transpose is not the identity permutation.
1533  const llvm::ArrayRef<int32_t> perms = getPerms();
1534 
1535  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1536  return {};
1537 
1538  return getInput1();
1539 }
1540 
1541 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1542  // Element-wise negate(negate(x)) = x
1543  // iff all zero points are constant 0
1544  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1545  if (!definingOp) {
1546  // defining op of input1 is not a negate, cannot fold
1547  return {};
1548  }
1549 
1550  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1551  failed(maybeIZp) || *maybeIZp != 0) {
1552  // input1 zero point is not constant 0, cannot fold
1553  return {};
1554  }
1555  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1556  failed(maybeOZp) || *maybeOZp != 0) {
1557  // output zero point is not constant 0, cannot fold
1558  return {};
1559  }
1560  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1561  failed(maybeIZp) || *maybeIZp != 0) {
1562  // definingOp's input1 zero point is not constant 0, cannot fold
1563  return {};
1564  }
1565  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1566  failed(maybeOZp) || *maybeOZp != 0) {
1567  // definingOp's output zero point is not constant 0, cannot fold
1568  return {};
1569  }
1570 
1571  return definingOp.getInput1();
1572 }
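A schematic example of the fold above (illustrative only): when all four zero points are constant 0, a double negation collapses to the original value.
//   %y = tosa.negate %x, %izp0, %ozp0
//   %z = tosa.negate %y, %izp1, %ozp1   -->   %z folds to %x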
1573 
1574 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1575  auto input = getInput1();
1576  // Element-wise abs(abs(x)) = abs(x)
1577  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1578  return input;
1579  }
1580 
1581  return {};
1582 }
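A schematic example (illustrative only): abs is idempotent, so the outer op folds to its operand.
//   %y = tosa.abs %x
//   %z = tosa.abs %y   -->   %z folds to %y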
1583 
1584 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1585  // Fold consecutive concats on the same axis into a single op.
1586  // Keep track of the operands so we are able to construct a new concat
1587  // later. Conservatively reserve space for double the number of operands,
1588  // since folding may expand the list.
1589  SmallVector<Value, 8> concatOperands;
1590  concatOperands.reserve(2 * getNumOperands());
1591 
1592  // Find all operands that are foldable concats
1593  bool foundFoldableConcat = false;
1594  for (Value operand : getOperands()) {
1595  concatOperands.emplace_back(operand);
1596 
1597  auto producer = operand.getDefiningOp<ConcatOp>();
1598  if (!producer)
1599  continue;
1600 
1601  // Not foldable if axes are not the same
1602  if (getAxis() != producer.getAxis())
1603  continue;
1604 
1605  // Replace the original operand with all incoming operands
1606  foundFoldableConcat = true;
1607  concatOperands.pop_back();
1608  llvm::append_range(concatOperands, producer->getOperands());
1609  }
1610 
1611  if (!foundFoldableConcat)
1612  return {};
1613 
1614  getOperation()->setOperands(concatOperands);
1615  return getResult();
1616 }
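A schematic example of the fold above (illustrative only): a concat feeding another concat on the same axis is flattened by rewriting the outer op's operand list in place.
//   %c0 = tosa.concat %a, %b  { axis = 0 }
//   %c1 = tosa.concat %c0, %d { axis = 0 }
//   -->  %c1 becomes tosa.concat %a, %b, %d { axis = 0 }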
1617 
1618 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1619  auto input = adaptor.getInput1();
1620 
1621  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1622  // Fold splat inputs only.
1623  if (!inputAttr || !inputAttr.isSplat())
1624  return {};
1625 
1626  auto shapeType = llvm::cast<ShapedType>(getType());
1627  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1628  auto floatVal = inputAttr.getSplatValue<APFloat>();
1629  return DenseElementsAttr::get(shapeType,
1630  ReciprocalOp::calcOneElement(floatVal));
1631  }
1632 
1633  return {};
1634 }
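A schematic example of the fold above (illustrative only): a splat float input folds to a splat of its reciprocal, computed with calcOneElement.
//   %c = "tosa.const"() <{ values = dense<2.0> : tensor<1xf32> }>
//   %r = tosa.reciprocal %c   -->   %r folds to dense<5.000000e-01>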