1 //===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // TOSA canonicalization patterns and folders.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
17 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/APInt.h"
26 
27 #include <functional>
28 
29 using namespace mlir;
30 using namespace mlir::tosa;
31 
32 //===----------------------------------------------------------------------===//
33 // Operator Canonicalizers.
34 //===----------------------------------------------------------------------===//
35 
36 //===----------------------------------------------------------------------===//
37 // Tensor Data Engine Operators.
38 //===----------------------------------------------------------------------===//
39 
40 // Check that the zero point of the tensor operation and the pad constant align.
41 static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
42  // Check that padConst is a constant value and a scalar tensor
43  DenseElementsAttr padConstAttr;
44  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
45  (padConstAttr.size() != 1)) {
46  return false;
47  }
48 
49  // Check that floating point pad is zero
50  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
51  float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
52  return padConstVal == 0.0f;
53  }
54 
55  // Check that the zp and padConst align for the integer (quantized) case
56  if (auto padConstIntAttr =
57  mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
58  DenseIntElementsAttr zpAttr;
59  // Check that zp is a constant value and a scalar tensor
60  if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
61  return false;
62  }
63 
64  // Check equality
65  int64_t zpVal = (*zpAttr.begin()).getSExtValue();
66  int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
67  return zpVal == padConstVal;
68  }
69 
70  // Bail-out on unsupported type
71  return false;
72 }
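// Illustrative note (not from the original source): implicit padding in TOSA
// tensor operations pads with the input zero-point, so the helper above only
// accepts explicit pad values that decode to the same thing. A minimal sketch,
// assuming an i8 input with input_zp = -128:
//   pad_const = splat<-128>, input_zp = -128  -> foldable (values match)
//   pad_const = splat<0>,    input_zp = -128  -> not foldable (the explicit
//                                                pad would change the result)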
73 
74 namespace {
75 template <typename OpTy>
76 struct PoolPadFoldAdaptor;
77 
78 template <>
79 struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
80  using OpTy = tosa::AvgPool2dOp;
81  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
82  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
83  if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
84  newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
85  return false;
86  return true;
87  }
88  static bool checkPadConstCompliance(OpTy op, Value padConst) {
89  return checkMatchingPadConstAndZp(padConst, op.getInputZp());
90  }
91  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
92  Value padInput, ArrayRef<int64_t> newPad) {
93  rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
94  op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
95  op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
96  op.getAccType());
97  }
98 };
99 
100 template <>
101 struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
102  using OpTy = tosa::MaxPool2dOp;
103  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
104  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
105  if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
106  newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
107  return false;
108  return true;
109  }
110  static bool checkPadConstCompliance(OpTy, Value padConst) {
111  // Check that padConst is a constant value and a scalar tensor
112  DenseElementsAttr padConstAttr;
113  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
114  padConstAttr.size() != 1) {
115  return false;
116  }
117 
118  // The pad constant must be the lowest representable value to allow the merge
119  if (auto padConstFpAttr =
120  mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
121  const APFloat padConstVal = *padConstFpAttr.begin();
122  const APFloat lowestVal =
123  APFloat::getLargest(padConstVal.getSemantics(), true);
124  return padConstVal == lowestVal;
125  }
126  if (auto padConstIntAttr =
127  mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
128  const APInt padConstVal = *padConstIntAttr.begin();
129  const unsigned int bitWidth = padConstVal.getBitWidth();
130  const APInt lowestVal =
131  padConstIntAttr.getElementType().isUnsignedInteger()
132  ? APInt::getZero(bitWidth)
133  : APInt::getSignedMinValue(bitWidth);
134  return padConstVal == lowestVal;
135  }
136 
137  // Bail-out on unsupported type
138  return false;
139  }
140  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
141  Value padInput, ArrayRef<int64_t> newPad) {
142  rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
143  op, op.getType(), padInput, op.getKernel(), op.getStride(),
144  rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
145  }
146 };
147 
148 template <typename OpTy>
149 struct ConvPadFoldAdaptor {
150  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
151  return true;
152  }
153  static bool checkPadConstCompliance(OpTy op, Value padConst) {
154  return checkMatchingPadConstAndZp(padConst, op.getInputZp());
155  }
156  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
157  Value padInput, ArrayRef<int64_t> newPad) {
158  rewriter.replaceOpWithNewOp<OpTy>(
159  op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
160  op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
161  op.getDilationAttr(), op.getAccType(), op.getLocalBound());
162  }
163 };
164 
165 // Pattern that attempts to fold a `tosa.pad` operator into a following tensor
166 // operation like `tosa.conv2d` by merging the explicit padding of the pad
167 // operator into the implicit padding of the tensor operation. This allows the
168 // unused explicit padding operator to be eliminated afterwards.
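//
// Illustrative example (simplified, not verbatim TOSA assembly): a tosa.pad of
// [0, 0, 1, 1, 2, 2, 0, 0] over the N/H/W/C dimensions feeding a tosa.conv2d
// with pad = [0, 0, 0, 0] can be rewritten as a conv2d with pad = [1, 1, 2, 2]
// applied directly to the unpadded input, provided the pad constant matches
// the conv2d input zero-point as checked by the adaptor.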
169 template <typename OpTy, typename AdaptorTy>
170 struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
171  using OpRewritePattern<OpTy>::OpRewritePattern;
172 
173  LogicalResult matchAndRewrite(OpTy tensorOp,
174  PatternRewriter &rewriter) const override {
175  // Check producer is a tosa::PadOp
176  auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
177  if (!padOp)
178  return rewriter.notifyMatchFailure(tensorOp,
179  "Producer must be a tosa::PadOp.");
180 
181  // Validate that tensor operation has sane padding
182  const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
183  if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
184  return rewriter.notifyMatchFailure(
185  tensorOp, "Tensor operation padding shall have 4 elements.");
186 
187  // Validate tosa::PadOp padding
188  DenseIntElementsAttr padOpPadding;
189  if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
190  return rewriter.notifyMatchFailure(
191  tensorOp,
192  "The `padding` input specified on the tosa::PadOp must be constant.");
193  }
194  // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
195  // C_after
196  if (padOpPadding.size() != 8)
197  return rewriter.notifyMatchFailure(tensorOp,
198  "Pad padding should have 8 elements.");
199  int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
200  int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
201  int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
202  int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
203  int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
204  int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
205  int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
206  int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
207 
208  if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
209  return rewriter.notifyMatchFailure(
210  tensorOp, "Folding padding in N or C dimensions is not supported.");
211 
212  // Fold padding from Pad into the tensor operation
213  // 4 elements - pad_top, pad_bottom, pad_left, pad_right
214  SmallVector<int64_t> foldedPad(tensorOpPad.size());
215  foldedPad[0] = padHBefore + tensorOpPad[0];
216  foldedPad[1] = padHAfter + tensorOpPad[1];
217  foldedPad[2] = padWBefore + tensorOpPad[2];
218  foldedPad[3] = padWAfter + tensorOpPad[3];
219 
220  // Check kernel related restrictions
221  if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
222  return rewriter.notifyMatchFailure(
223  tensorOp, "Padding size not aligned with kernel restrictions.");
224  }
225 
226  // Check padding constant restrictions
227  if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
228  return rewriter.notifyMatchFailure(
229  tensorOp,
230  "Padding constant is not aligned with operator zero-point.");
231  }
232 
233  // For now, check that the folded padding does not exceed the 8K level (8192)
234  if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
235  return rewriter.notifyMatchFailure(
236  tensorOp, "Padding size more than the 8K level limit.");
237  }
238 
239  // Create operator
240  AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
241  foldedPad);
242 
243  return success();
244  }
245 };
246 } // namespace
247 
248 void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
251  PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
252  context);
253 }
254 
255 void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
256  MLIRContext *context) {
257  results.add<
258  FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
259  context);
260 }
261 
262 void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
263  MLIRContext *context) {
264  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
265  ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
266  context);
267 }
268 
269 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
270  using OpRewritePattern::OpRewritePattern;
271 
272  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
273  PatternRewriter &rewriter) const override {
274  Value input = op.getInput();
275  Value output = op.getOutput();
276  ShapedType inputType = llvm::cast<ShapedType>(input.getType());
277  ShapedType outputType = llvm::cast<ShapedType>(output.getType());
278 
279  if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
280  return failure();
281  }
282 
283  // If the output and input shapes are 1x1, then this is a no op.
284  ArrayRef<int64_t> outputShape = outputType.getShape();
285  if (outputShape[1] != 1 || outputShape[2] != 1) {
286  return failure();
287  }
288 
289  ArrayRef<int64_t> inputShape = inputType.getShape();
290  if (inputShape[1] != 1 || inputShape[2] != 1) {
291  return failure();
292  }
293 
294  rewriter.replaceOp(op, input);
295  return success();
296  }
297 };
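// Illustrative example (assumed shapes, not from the original source): a
// tosa.max_pool2d whose input and output both have static 1x1 spatial
// dimensions, e.g. tensor<2x1x1x8xf32> -> tensor<2x1x1x8xf32>, pools over a
// single element per channel and is therefore replaced by its input.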
298 
299 void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
300  MLIRContext *context) {
301  results.add<MaxPool2dIsNoOp,
302  FoldPadToTensorOp<tosa::MaxPool2dOp,
303  PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
304  context);
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // Data Layout / Memory Reinterpretation.
309 //===----------------------------------------------------------------------===//
310 
311 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
312  using OpRewritePattern::OpRewritePattern;
313 
314  LogicalResult matchAndRewrite(tosa::ConcatOp op,
315  PatternRewriter &rewriter) const override {
316  if (op.getInput1().size() != 1)
317  return failure();
318  if (op.getInput1().front().getType() != op.getType()) {
319  rewriter
320  .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
321  op.getInput1().front())
322  .getResult();
323  return success();
324  }
325 
326  rewriter.replaceOp(op, op.getInput1().front());
327  return success();
328  }
329 };
330 
331 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
332  MLIRContext *context) {
333  results.add<ConcatOptimization>(context);
334 }
335 
336 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
337  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
338  if (!notOp)
339  return failure();
340  rewriter.modifyOpInPlace(op, [&]() {
341  op.getOperation()->setOperands(
342  {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
343  });
344  return success();
345 }
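// Illustrative note (not from the original source): the canonicalization above
// rewrites select(logical_not(%pred), %a, %b) into select(%pred, %b, %a) by
// swapping the on_true/on_false operands in place, which leaves the
// logical_not dead if it has no other users.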
346 
347 struct ConsolidateTransposeOptimization
348  : public OpRewritePattern<tosa::TransposeOp> {
349  using OpRewritePattern::OpRewritePattern;
350 
351  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
352  PatternRewriter &rewriter) const override {
353  // Input is also TransposeOp - transpose(transpose(A)).
354  auto innerTranspose =
355  transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
356  if (!innerTranspose)
357  return rewriter.notifyMatchFailure(transposeOp,
358  "input must be transpose operation");
359 
360  const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
361  const llvm::ArrayRef<int32_t> innerTransposePerms =
362  innerTranspose.getPerms();
363 
364  if (transposePerms.size() != innerTransposePerms.size())
365  return rewriter.notifyMatchFailure(
366  transposeOp,
367  "transpose and inner transpose perms sizes must be equal");
368  if (transposePerms.empty())
369  return rewriter.notifyMatchFailure(
370  transposeOp, "transpose perms sizes must be positive");
371 
372  // Consolidate transposes into one transpose.
373  SmallVector<int32_t> perms(transposePerms.size());
374  for (int i = 0, s = transposePerms.size(); i < s; ++i)
375  perms[i] = innerTransposePerms[transposePerms[i]];
376 
377  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
378  transposeOp, transposeOp.getResult().getType(),
379  innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
380 
381  return success();
382  }
383 };
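// Worked example (assumed permutations, not from the original source): with an
// inner transpose of perms = [1, 2, 0] followed by an outer transpose of
// perms = [2, 0, 1], the composed permutation is
//   perms[i] = innerPerms[outerPerms[i]] = [innerPerms[2], innerPerms[0],
//                                           innerPerms[1]] = [0, 1, 2],
// so the pair collapses into a single identity transpose, which the
// TransposeOp folder further below then removes entirely.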
384 
385 // Determines the case when tosa.transpose is a tosa.reshape operation.
386 struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
387  using OpRewritePattern::OpRewritePattern;
388 
389  LogicalResult matchAndRewrite(tosa::TransposeOp op,
390  PatternRewriter &rewriter) const override {
391  if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
392  return rewriter.notifyMatchFailure(
393  op, "Src is from transpose, can compose transposes");
394 
395  Value result = op.getResult();
396  for (Operation *subop : result.getUsers()) {
397  if (isa_and_nonnull<tosa::TransposeOp>(subop))
398  return rewriter.notifyMatchFailure(
399  op, "Dest is used by transpose, can compose transposes");
400  }
401 
402  auto input = op.getInput1();
403  auto inputTy = llvm::cast<ShapedType>(input.getType());
404  if (!inputTy.hasRank())
405  return rewriter.notifyMatchFailure(op, "Unranked input.");
406 
407  int64_t numDynDims = 0;
408  for (int i = 0; i < inputTy.getRank(); ++i)
409  if (inputTy.isDynamicDim(i))
410  numDynDims++;
411 
412  if (numDynDims > 1)
413  return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
414 
415  const llvm::ArrayRef<int32_t> permValues = op.getPerms();
416 
417  SmallVector<int64_t> nonZeroPerms;
418  nonZeroPerms.reserve(permValues.size());
419  for (auto idx : permValues) {
420  auto sz = inputTy.getDimSize(idx);
421  if (sz != 1)
422  nonZeroPerms.push_back(idx);
423  }
424 
425  for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
426  if (nonZeroPerms[i - 1] > nonZeroPerms[i])
427  return rewriter.notifyMatchFailure(op,
428  "Transpose changes memory layout.");
429 
430  SmallVector<int64_t> newShape;
431  newShape.reserve(inputTy.getRank());
432  for (int i = 0, s = inputTy.getRank(); i < s; ++i)
433  newShape.push_back(inputTy.getDimSize(permValues[i]));
434 
435  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
436  op, op.getType(), op.getInput1(),
437  getTosaConstShape(rewriter, op.getLoc(), newShape));
438  return success();
439  }
440 };
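// Illustrative example (assumed shapes, not from the original source): a
// transpose of tensor<1x1x4x8xf32> with perms = [2, 0, 3, 1] only moves unit
// dimensions around; the non-unit dimensions (4, then 8) keep their relative
// order, so the op is rewritten as a reshape to tensor<4x1x8x1xf32>.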
441 
442 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
443  MLIRContext *context) {
444  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
445 }
446 
447 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
448  using OpRewritePattern::OpRewritePattern;
449 
450  LogicalResult matchAndRewrite(tosa::ClampOp op,
451  PatternRewriter &rewriter) const override {
452  Value input = op.getInput();
453  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
454  auto inputElementType = inputType.getElementType();
455 
456  if (isa<FloatType>(inputElementType)) {
457  // Unlike integer types, floating point types can represent infinity.
458  const auto minClamp =
459  llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
460  const auto maxClamp =
461  llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
462  const bool isMin = minClamp.isNegInfinity();
463  const bool isMax = maxClamp.isInfinity();
464 
465  if (isMin && isMax) {
466  rewriter.replaceOp(op, input);
467  return success();
468  }
469  return failure();
470  }
471 
472  // i1 types are boolean in TOSA
473  const bool isBoolean = inputElementType.isInteger(1);
474  if (inputElementType.isUnsignedInteger() || isBoolean) {
475  const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
476  .getValue()
477  .getZExtValue();
478  const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
479  .getValue()
480  .getZExtValue();
481 
482  const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
483  const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
484  const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
485 
486  if (minClamp <= intMin && maxClamp >= intMax) {
487  rewriter.replaceOp(op, input);
488  return success();
489  }
490  return failure();
491  }
492 
493  if (llvm::isa<IntegerType>(inputElementType)) {
494  const int64_t minClamp =
495  llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
496  const int64_t maxClamp =
497  llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
498 
499  const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
500  const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
501  const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
502 
503  if (minClamp <= intMin && maxClamp >= intMax) {
504  rewriter.replaceOp(op, input);
505  return success();
506  }
507  return failure();
508  }
509 
510  return failure();
511  }
512 };
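// Illustrative example (not from the original source): a tosa.clamp on an i8
// tensor with min_val = -128 and max_val = 127 covers the full representable
// range of the element type, so it is a no-op and is replaced by its input.
// The same holds for a float clamp with min_val = -inf and max_val = +inf.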
513 
514 // Attempts the following transformation:
515 //
516 // For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
517 // tensor X the following identity holds:
518 //
519 // CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
520 //
521 // subject to the following valid NaN propagation semantics:
522 // --------------------------------------------
523 // | OUTER CLAMP | INNER CLAMP | RESULT MODE |
524 // |-------------|--------------|-------------|
525 // | PROPAGATE | PROPAGATE | PROPAGATE |
526 // | PROPAGATE | IGNORE | IGNORE |
527 // | IGNORE | PROPAGATE | INVALID |
528 // | IGNORE | IGNORE | IGNORE |
529 // |------------------------------------------|
530 
531 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
532  using OpRewritePattern::OpRewritePattern;
533 
534  // Helper structure to describe the range of a clamp operation.
535  template <typename T>
536  struct ClampRange {
537  ClampRange(const T &start, const T &end) : start(start), end(end) {}
538  T start;
539  T end;
540 
541  // Helper function to determine if two Clamp ranges intersect.
542  bool intersects(const ClampRange<T> &otherRange) {
543  return start < otherRange.end && otherRange.start < end;
544  }
545  };
546 
547  LogicalResult matchAndRewrite(tosa::ClampOp op,
548  PatternRewriter &rewriter) const override {
549  Value input = op.getInput();
550 
551  // Check the input to the CLAMP op is itself a CLAMP.
552  auto clampOp = input.getDefiningOp<tosa::ClampOp>();
553  if (!clampOp)
554  return failure();
555 
556  // Check we have a valid NaN propagation combination.
557  const auto opNanMode = op.getNanMode();
558  const auto clampNanMode = clampOp.getNanMode();
559  if (opNanMode == NanPropagationMode::IGNORE &&
560  clampNanMode == NanPropagationMode::PROPAGATE)
561  return failure();
562 
563  auto maxValAttr = op.getMaxValAttr();
564  auto minValAttr = op.getMinValAttr();
565  auto clampOpMaxValAttr = clampOp.getMaxValAttr();
566  auto clampOpMinValAttr = clampOp.getMinValAttr();
567 
568  auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
569  if (auto quantType =
570  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
571  inputEType = quantType.getStorageType();
572  }
573 
574  Attribute newMinValAttr, newMaxValAttr;
575  if (mlir::isa<FloatType>(inputEType)) {
576  auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
577  auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
578  auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
579  auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
580 
581  // Check we have intersecting ranges.
582  const auto opMinFloat = floatMinValAttr.getValue();
583  const auto opMaxFloat = floatMaxValAttr.getValue();
584  const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
585  const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
586  ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
587  ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
588  clampOpMaxFloat);
589  if (!opRangeFloatRange.intersects(clampRangeFloatRange))
590  return failure();
591 
592  // Run the transformation.
593  auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
594  auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
595  newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
596  newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
597  } else {
598  assert(mlir::isa<IntegerType>(inputEType));
599  auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
600  auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
601  auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
602  auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
603 
604  if (inputEType.isUnsignedInteger()) {
605  // Check we have intersecting ranges.
606  const auto opMinInt = intMinValAttr.getUInt();
607  const auto opMaxInt = intMaxValAttr.getUInt();
608  const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
609  const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
610  ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
611  ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
612  clampOpMaxInt);
613  if (!opRangeIntRange.intersects(clampRangeIntRange))
614  return failure();
615 
616  // Run the transformation.
617  auto newMinVal = std::max(opMinInt, clampOpMinInt);
618  auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
619  newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
620  newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
621  } else {
622  // Check we have intersecting ranges.
623  const auto opMinInt = intMinValAttr.getInt();
624  const auto opMaxInt = intMaxValAttr.getInt();
625  const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
626  const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
627  ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
628  ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
629  clampOpMaxInt);
630  if (!opRangeIntRange.intersects(clampRangeIntRange))
631  return failure();
632 
633  // Run the transformation.
634  auto newMinVal = std::max(opMinInt, clampOpMinInt);
635  auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
636  newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
637  newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
638  }
639  }
640 
641  auto newMode = (opNanMode != clampNanMode)
642  ? tosa::NanPropagationMode::IGNORE
643  : opNanMode;
644 
645  auto newModeAttr =
646  NanPropagationModeAttr::get(rewriter.getContext(), newMode);
647 
648  rewriter.replaceOpWithNewOp<tosa::ClampOp>(
649  op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
650  newModeAttr);
651  return success();
652  }
653 };
654 
655 void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
656  MLIRContext *context) {
657  results.add<ClampIsNoOp>(context);
658  results.add<ClampClampOptimization>(context);
659 }
660 
661 struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
662  using OpRewritePattern::OpRewritePattern;
663 
664  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
665  PatternRewriter &rewriter) const override {
666  Value sliceInput = sliceOp.getInput1();
667  auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
668  if (!concatOp)
669  return rewriter.notifyMatchFailure(
670  sliceOp, "slice input must be concat operation");
671 
672  OperandRange inputs = concatOp.getInput1();
673  auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
674  if (!concatType || !concatType.hasStaticShape())
675  return rewriter.notifyMatchFailure(
676  sliceOp, "slice input must be a static ranked tensor");
677  int32_t axis = concatOp.getAxis();
678 
679  DenseElementsAttr startElems;
680  DenseElementsAttr sizeElems;
681 
682  if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
683  return rewriter.notifyMatchFailure(
684  sliceOp, "start of slice must be a static ranked shape");
685 
686  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
687  return rewriter.notifyMatchFailure(
688  sliceOp, "size of slice must be a static ranked shape");
689 
690  llvm::SmallVector<int64_t> sliceStarts =
691  llvm::to_vector(startElems.getValues<int64_t>());
692  llvm::SmallVector<int64_t> sliceSizes =
693  llvm::to_vector(sizeElems.getValues<int64_t>());
694 
695  // Validate slice on the concatenated axis. Slicing along this
696  // axis should span only one of the inputs to the concatenate
697  // operation.
698  std::optional<Value> replaceWithSlice;
699  for (auto input : inputs) {
700  auto inputType = dyn_cast<RankedTensorType>(input.getType());
701  if (!inputType || !inputType.hasStaticShape())
702  return rewriter.notifyMatchFailure(
703  sliceOp, "concat input must be a static ranked tensor");
704 
705  if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
706  inputType.getDimSize(axis)) {
707  auto start_op =
708  getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
709  auto size_op =
710  getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
711  replaceWithSlice =
712  tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
713  input, start_op, size_op)
714  .getResult();
715  break;
716  }
717  sliceStarts[axis] -= inputType.getDimSize(axis);
718  }
719 
720  if (!replaceWithSlice)
721  return rewriter.notifyMatchFailure(
722  sliceOp, "corresponding concat input not found for slice");
723 
724  rewriter.replaceOp(sliceOp, replaceWithSlice.value());
725  return success();
726  }
727 };
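// Worked example (assumed shapes, not from the original source): for
// concat(%a : tensor<1x4xf32>, %b : tensor<1x8xf32>) on axis = 1 followed by a
// slice with start = [0, 6] and size = [1, 2], the requested window lies
// entirely inside %b. After subtracting the size of %a along the axis, the
// slice is rewritten to read from %b directly with start = [0, 2] and
// size = [1, 2], and the concat may then become dead.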
728 
729 struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
730  using OpRewritePattern::OpRewritePattern;
731 
732  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
733  PatternRewriter &rewriter) const override {
734  Value sliceInput = sliceOp.getInput1();
735 
736  // Check if producer is a PadOp
737  auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
738  if (!padOp)
739  return rewriter.notifyMatchFailure(sliceOp,
740  "slice input must be a pad operation");
741 
742  // Check PadOp has a single consumer
743  if (!padOp->hasOneUse())
744  return rewriter.notifyMatchFailure(sliceOp,
745  "pad shall have a single consumer");
746 
747  // Check input is statically ranked
748  auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
749  auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
750  if (!inputTy || !padTy || !inputTy.hasRank())
751  return rewriter.notifyMatchFailure(sliceOp,
752  "slice input must be a ranked tensor");
753 
754  // Validate and extract tosa::PadOp padding
755  DenseIntElementsAttr paddingElems;
756  if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
757  return rewriter.notifyMatchFailure(
758  sliceOp,
759  "`padding` input specified on the tosa::PadOp must be constant.");
760  }
761  llvm::SmallVector<int64_t> padPaddings =
762  llvm::to_vector(paddingElems.getValues<int64_t>());
763 
764  // Extract slice parameters
765  DenseElementsAttr startElems;
766  if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
767  return rewriter.notifyMatchFailure(
768  sliceOp, "start of slice must be a static ranked shape");
769  llvm::SmallVector<int64_t> sliceStarts =
770  llvm::to_vector(startElems.getValues<int64_t>());
771 
772  DenseElementsAttr sizeElems;
773  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
774  return rewriter.notifyMatchFailure(
775  sliceOp, "size of slice must be a static ranked shape");
776  llvm::SmallVector<int64_t> sliceSizes =
777  llvm::to_vector(sizeElems.getValues<int64_t>());
778 
779  // Check if dynamic dimensions are sliced
780  const int64_t rank = inputTy.getRank();
781  if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
782  const bool isDimDynamic = inputTy.isDynamicDim(i);
783  const bool isDimSliced =
784  (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
785 
786  return isDimDynamic && isDimSliced;
787  })) {
788  return rewriter.notifyMatchFailure(
789  sliceOp, "axis that are sliced shall be statically known.");
790  }
791 
792  // Update the parameters
793  llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
794  llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
795  llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
796  bool updated = false;
797 
798  for (int64_t i = 0; i < rank; ++i) {
799  const int64_t padLo = padPaddings[i * 2];
800  const int64_t padHi = padPaddings[i * 2 + 1];
801  const int64_t sliceStart = sliceStarts[i];
802  const int64_t sliceSize = sliceSizes[i];
803  const int64_t sliceEnd = sliceStart + sliceSize;
804 
805  // If dimension is dynamic pass-through
806  if (inputTy.isDynamicDim(i)) {
807  newPadPaddings[i * 2] = padLo;
808  newPadPaddings[i * 2 + 1] = padHi;
809  newSliceStarts[i] = sliceStart;
810  continue;
811  }
812 
813  // Handle static dimensions
814  const int64_t dimSize = inputTy.getShape()[i];
815  const int64_t dimTotal = padLo + dimSize + padHi;
816 
817  // Check slice within bounds
818  if (sliceStart < 0 || sliceEnd > dimTotal)
819  return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
820 
821  // Compute updated slice start parameter
822  const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
823  newSliceStarts[i] = newSliceStart;
824  updated |= newSliceStart != sliceStart;
825 
826  // Compute updated pad parameters
827  const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
828  const int64_t newPadHi =
829  std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
830  newPadPaddings[i * 2] = newPadLo;
831  newPadPaddings[i * 2 + 1] = newPadHi;
832  updated |= (newPadLo != padLo) || (newPadHi != padHi);
833 
834  // Calculate new pad output shape
835  newPadShape[i] =
836  newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
837  }
838 
839  // Check that we actually need to proceed with the rewrite
840  if (!updated)
841  return rewriter.notifyMatchFailure(
842  sliceOp, "terminate condition; nothing to rewrite");
843 
844  // Create a PadOp with updated padding
845  auto newPaddingsOp =
846  getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
847  auto newPadTy =
848  RankedTensorType::get(newPadShape, inputTy.getElementType());
849  auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
850  padOp.getInput1(), newPaddingsOp,
851  padOp.getPadConst());
852 
853  // Update SliceOp and point to new PadOp
854  auto newStartOp =
855  getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
856  rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
857  newPadOp.getResult(), newStartOp,
858  sliceOp.getSize());
859 
860  return success();
861  }
862 };
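// Worked example (assumed sizes, not from the original source): for one
// dimension with input size 8, pad (lo, hi) = (2, 2), slice start = 1 and
// slice size = 9 (covering positions [1, 10) of the 12 padded elements):
//   newSliceStart = max(1 - 2, 0)        = 0
//   newPadLo      = max(2 - 1, 0)        = 1
//   newPadHi      = max(10 - (2 + 8), 0) = 0
// so the rewritten pad produces exactly the 9 elements the slice needs and
// the slice along that dimension now starts at 0.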
863 
864 // Update size operand of tosa.slice if size has dynamic dims but corresponding
865 // output dim is static
866 struct SliceDynamicSizeCanonicalization
867  : public OpRewritePattern<tosa::SliceOp> {
868  using OpRewritePattern::OpRewritePattern;
869 
870  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
871  PatternRewriter &rewriter) const override {
872  ShapedType resultType = cast<ShapedType>(sliceOp.getType());
873 
874  ElementsAttr sizeElems;
875  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
876  return rewriter.notifyMatchFailure(
877  sliceOp, "size of slice must be a static ranked shape");
878  }
879 
880  llvm::SmallVector<int64_t> sliceSizes =
881  llvm::to_vector(sizeElems.getValues<int64_t>());
882 
883  bool replaceSliceSize{false};
884  // If the size operand has -1 (dynamic) for a dimension whose corresponding
885  // output dimension is statically known, update the size to match the known
886  // output dimension
887  for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
888  if (size == -1 && !resultType.isDynamicDim(index)) {
889  sliceSizes[index] = resultType.getDimSize(index);
890  replaceSliceSize = true;
891  }
892  }
893 
894  if (!replaceSliceSize) {
895  return rewriter.notifyMatchFailure(
896  sliceOp, "no dimension of size of slice is dynamic that resolves "
897  "to static output shape");
898  }
899 
900  auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
901  auto newSliceOp =
902  tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
903  sliceOp.getInput1(), sliceOp.getStart(), size_op);
904 
905  rewriter.replaceOp(sliceOp, newSliceOp.getResult());
906  return success();
907  }
908 };
909 
910 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
911  MLIRContext *context) {
912  results.add<ConcatSliceOptimization, PadSliceOptimization,
913  SliceDynamicSizeCanonicalization>(context);
914 }
915 
916 //===----------------------------------------------------------------------===//
917 // Operator Folders.
918 //===----------------------------------------------------------------------===//
919 
920 template <typename IntFolder, typename FloatFolder>
921 DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
922  RankedTensorType returnTy) {
923  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
924  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
925  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
926  if (lETy != rETy)
927  return {};
928 
929  if (llvm::isa<IntegerType>(lETy)) {
930  APInt l = lhs.getSplatValue<APInt>();
931  APInt r = rhs.getSplatValue<APInt>();
932  auto result = IntFolder()(l, r);
933  return DenseElementsAttr::get(returnTy, result);
934  }
935 
936  if (llvm::isa<FloatType>(lETy)) {
937  APFloat l = lhs.getSplatValue<APFloat>();
938  APFloat r = rhs.getSplatValue<APFloat>();
939  auto result = FloatFolder()(l, r);
940  return DenseElementsAttr::get(returnTy, result);
941  }
942  }
943 
944  return {};
945 }
946 
947 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
948  if (llvm::isa<FloatType>(elemType))
949  return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
950  if (llvm::isa<IntegerType>(elemType))
951  return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
952  return false;
953 }
954 
955 static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
956  if (llvm::isa<FloatType>(elemType))
957  return val && val.isSplat() &&
958  val.getSplatValue<APFloat>().isExactlyValue(1.0);
959  if (llvm::isa<IntegerType>(elemType)) {
960  const int64_t shifted = 1LL << shift;
961  return val && val.isSplat() &&
962  val.getSplatValue<APInt>().getSExtValue() == shifted;
963  }
964  return false;
965 }
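// Illustrative note (not from the original source): tosa.mul on i32 applies a
// right shift to the product, so the multiplicative identity for a given
// shift s is 1 << s rather than 1. For example, with shift = 7 a splat
// constant of 128 leaves the other operand unchanged, which is exactly what
// isSplatOne checks.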
966 
967 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
968  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
969  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
970  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
971  if (!lhsTy || !rhsTy || !resultTy)
972  return {};
973 
974  // Cannot create an ElementsAttr from non-int/float/index types
975  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
976  !rhsTy.getElementType().isIntOrIndexOrFloat())
977  return {};
978 
979  auto resultETy = resultTy.getElementType();
980  auto lhsAttr =
981  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
982  auto rhsAttr =
983  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
984 
985  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
986  return getInput1();
987  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
988  return getInput2();
989 
990  if (!lhsAttr || !rhsAttr)
991  return {};
992 
993  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
994  resultTy);
995 }
996 
997 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
998  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
999  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1000  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1001  !outputTy.hasStaticShape())
1002  return {};
1003 
1004  if (inputTy.getDimSize(getAxis()) == 1)
1005  return DenseElementsAttr::get(outputTy, 0);
1006 
1007  return {};
1008 }
1009 
1010 OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1011  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1012  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1013  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1014  if (!lhsTy || !rhsTy || !resultTy)
1015  return {};
1016  if (lhsTy != rhsTy)
1017  return {};
1018 
1019  // IntDivOp inputs must be integer type, no need to check for quantized type
1020  auto resultETy = resultTy.getElementType();
1021  auto lhsAttr =
1022  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1023  auto rhsAttr =
1024  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1025  if (lhsAttr && lhsAttr.isSplat()) {
1026  if (llvm::isa<IntegerType>(resultETy) &&
1027  lhsAttr.getSplatValue<APInt>().isZero())
1028  return lhsAttr;
1029  }
1030 
1031  if (rhsAttr && rhsAttr.isSplat()) {
1032  if (llvm::isa<IntegerType>(resultETy) &&
1033  rhsAttr.getSplatValue<APInt>().isOne())
1034  return getInput1();
1035  }
1036 
1037  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1038  llvm::isa<IntegerType>(resultETy)) {
1039  APInt l = lhsAttr.getSplatValue<APInt>();
1040  APInt r = rhsAttr.getSplatValue<APInt>();
1041  if (!r.isZero()) {
1042  APInt result = l.sdiv(r);
1043  return DenseElementsAttr::get(resultTy, result);
1044  }
1045  }
1046 
1047  return {};
1048 }
1049 
1050 namespace {
1051 // Calculate (lhs * rhs) >> shift according to the TOSA spec.
1052 // Returns std::nullopt if the result is out of int32_t range when shift > 0.
1053 std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1054  unsigned bitwidth) {
1055  APInt result = lhs.sext(64) * rhs.sext(64);
1056 
1057  if (shift > 0) {
1058  auto round = APInt(64, 1) << (shift - 1);
1059  result += round;
1060  result.ashrInPlace(shift);
1061  // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
1062  if (!(result.getSExtValue() >= INT32_MIN &&
1063  result.getSExtValue() <= INT32_MAX)) {
1064  // REQUIRE failed
1065  return std::nullopt;
1066  }
1067  }
1068 
1069  return result.trunc(bitwidth);
1070 }
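// Worked example (not from the original source): for lhs = 5, rhs = 3 and
// shift = 1 the 64-bit product is 15; the rounding term 1 << (shift - 1) = 1
// brings it to 16, and the arithmetic shift right by 1 yields 8. With
// shift = 0 the result is simply 15 and no int32_t range check is performed.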
1071 
1072 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1073  RankedTensorType ty, int32_t shift) {
1074  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1075  if (llvm::isa<IntegerType>(ty.getElementType())) {
1076  APInt l = lhs.getSplatValue<APInt>();
1077  APInt r = rhs.getSplatValue<APInt>();
1078 
1079  if (shift == 0) {
1080  return DenseElementsAttr::get(ty, l * r);
1081  }
1082 
1083  auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1084  const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1085  if (!result)
1086  return {};
1087  return DenseElementsAttr::get(ty, result.value());
1088  }
1089 
1090  if (llvm::isa<FloatType>(ty.getElementType())) {
1091  APFloat l = lhs.getSplatValue<APFloat>();
1092  APFloat r = rhs.getSplatValue<APFloat>();
1093  APFloat result = l * r;
1094  return DenseElementsAttr::get(ty, result);
1095  }
1096  }
1097 
1098  return {};
1099 }
1100 } // namespace
1101 
1102 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1103  auto lhs = getInput1();
1104  auto rhs = getInput2();
1105  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1106  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1107  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1108  if (!lhsTy || !rhsTy || !resultTy)
1109  return {};
1110 
1111  auto resultETy = resultTy.getElementType();
1112  auto lhsAttr =
1113  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1114  auto rhsAttr =
1115  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1116 
1117  // Result right shift on i32_t data type only. For simplification, synthesize
1118  // a zero shift for other data types.
1119  int32_t shift = 0;
1120  if (resultETy.isInteger(32)) {
1121  ElementsAttr shift_elem;
1122  if (getShift().getImpl()) {
1123  if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1124  // cannot be folded when the shift value is unknown.
1125  return {};
1126  shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1127  }
1128  }
1129 
1130  if (rhsTy == resultTy) {
1131  if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1132  // constant values can only be resized if resulting type is static
1133  return lhsAttr.resizeSplat(resultTy);
1134  if (isSplatOne(resultETy, lhsAttr, shift))
1135  return rhs;
1136  }
1137  if (lhsTy == resultTy) {
1138  if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1139  return rhsAttr.resizeSplat(resultTy);
1140  if (isSplatOne(resultETy, rhsAttr, shift))
1141  return lhs;
1142  }
1143 
1144  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1145 }
1146 
1147 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1148  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1149  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1150  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1151  if (!lhsTy || !rhsTy || !resultTy)
1152  return {};
1153 
1154  // Cannot create an ElementsAttr from non-int/float/index types
1155  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1156  !rhsTy.getElementType().isIntOrIndexOrFloat())
1157  return {};
1158 
1159  auto resultETy = resultTy.getElementType();
1160  auto lhsAttr =
1161  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1162  auto rhsAttr =
1163  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1164 
1165  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1166  return getInput1();
1167 
1168  if (!lhsAttr || !rhsAttr)
1169  return {};
1170 
1171  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
1172  resultTy);
1173 }
1174 
1175 namespace {
1176 template <typename Cmp>
1177 struct ComparisonFold {
1178  ComparisonFold() = default;
1179  APInt operator()(const APInt &l, const APInt &r) {
1180  return APInt(1, Cmp()(l, r));
1181  }
1182 
1183  APInt operator()(const APFloat &l, const APFloat &r) {
1184  return APInt(1, Cmp()(l, r));
1185  }
1186 };
1187 
1188 struct APIntFoldGreater {
1189  APIntFoldGreater() = default;
1190  APInt operator()(const APInt &l, const APInt &r) {
1191  return APInt(1, l.sgt(r));
1192  }
1193 };
1194 
1195 struct APIntFoldGreaterEqual {
1196  APIntFoldGreaterEqual() = default;
1197  APInt operator()(const APInt &l, const APInt &r) {
1198  return APInt(1, l.sge(r));
1199  }
1200 };
1201 } // namespace
1202 
1203 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1204  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1205  auto lhsAttr =
1206  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1207  auto rhsAttr =
1208  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1209 
1210  if (!lhsAttr || !rhsAttr)
1211  return {};
1212 
1213  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1214  lhsAttr, rhsAttr, resultTy);
1215 }
1216 
1217 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1218  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1219  auto lhsAttr =
1220  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1221  auto rhsAttr =
1222  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1223 
1224  if (!lhsAttr || !rhsAttr)
1225  return {};
1226 
1227  return binaryFolder<APIntFoldGreaterEqual,
1228  ComparisonFold<std::greater_equal<APFloat>>>(
1229  lhsAttr, rhsAttr, resultTy);
1230 }
1231 
1232 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1233  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1234  auto lhsAttr =
1235  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1236  auto rhsAttr =
1237  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1238  Value lhs = getInput1();
1239  Value rhs = getInput2();
1240  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1241 
1242  // If we are comparing an integer value to itself it is always true. We
1243  // cannot do this for floats because NaN never compares equal to itself.
1244  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1245  resultTy.hasStaticShape() && lhs == rhs) {
1246  return DenseElementsAttr::get(resultTy, true);
1247  }
1248 
1249  if (!lhsAttr || !rhsAttr)
1250  return {};
1251 
1252  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1253  ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1254  resultTy);
1255 }
1256 
1257 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1258  if (getInput().getType() == getType())
1259  return getInput();
1260 
1261  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1262  if (!operand)
1263  return {};
1264 
1265  auto inTy = llvm::cast<ShapedType>(getInput().getType());
1266  auto outTy = llvm::cast<ShapedType>(getType());
1267  auto inETy = inTy.getElementType();
1268  auto outETy = outTy.getElementType();
1269 
1270  if (operand.isSplat()) {
1271  if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1272  bool overflow;
1273  auto splatVal = operand.getSplatValue<APFloat>();
1274  auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1275  splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1276  &overflow);
1277  return SplatElementsAttr::get(outTy, splatVal);
1278  }
1279 
1280  if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1281  auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1282  APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1283  splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1284  llvm::RoundingMode::NearestTiesToEven);
1285  return SplatElementsAttr::get(outTy, splatVal);
1286  }
1287 
1288  if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1289  auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1290  auto intVal = APSInt(
1291  llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1292  auto floatVal = operand.getSplatValue<APFloat>();
1293  bool exact;
1294  floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1295  &exact);
1296  return SplatElementsAttr::get(outTy, intVal);
1297  }
1298 
1299  if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1300  const auto inIntType = llvm::cast<IntegerType>(inETy);
1301  auto unsignIn = inIntType.isUnsignedInteger();
1302  bool trunc =
1303  inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1304  auto intVal = operand.getSplatValue<APInt>();
1305  auto bitwidth = outETy.getIntOrFloatBitWidth();
1306 
1307  // i1 types are boolean in TOSA
1308  if (outETy.isInteger(1)) {
1309  intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1310  } else if (trunc) {
1311  intVal = intVal.trunc(bitwidth);
1312  } else if (unsignIn || inIntType.isInteger(1)) {
1313  intVal = intVal.zext(bitwidth);
1314  } else {
1315  intVal = intVal.sext(bitwidth);
1316  }
1317 
1318  return SplatElementsAttr::get(outTy, intVal);
1319  }
1320  }
1321 
1322  return {};
1323 }
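// Illustrative note (not from the original source) on the splat int-to-int
// case above: a splat i8 value of -1 cast to i32 is sign-extended to -1, while
// a splat ui8 value of 255 cast to i32 is zero-extended to 255; a cast to i1
// instead yields 1 for any non-zero input, matching TOSA's treatment of i1 as
// a boolean type.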
1324 
1325 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1326 
1327 OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1328 
1329 #define REDUCE_FOLDER(OP) \
1330  OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1331  ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1332  if (!inputTy.hasRank()) \
1333  return {}; \
1334  if (inputTy != getType()) \
1335  return {}; \
1336  if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1337  return getInput(); \
1338  return {}; \
1339  }
1340 
1341 REDUCE_FOLDER(ReduceAllOp)
1342 REDUCE_FOLDER(ReduceAnyOp)
1343 REDUCE_FOLDER(ReduceMaxOp)
1344 REDUCE_FOLDER(ReduceMinOp)
1345 REDUCE_FOLDER(ReduceProductOp)
1346 REDUCE_FOLDER(ReduceSumOp)
1347 #undef REDUCE_FOLDER
1348 
1349 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1350  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1351  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1352 
1353  if (!inputTy || !outputTy)
1354  return {};
1355 
1356  // Fold when the input and output types are the same. This is only safe when
1357  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
1358  // there may still be a productive reshape.
1359  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1360  return getInput1();
1361 
1362  // reshape(reshape(x)) -> reshape(x)
1363  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1364  getInput1().getDefiningOp())) {
1365  getInput1Mutable().assign(reshapeOp.getInput1());
1366  return getResult();
1367  }
1368 
1369  // Cannot create an ElementsAttr from non-int/float/index types
1370  if (!inputTy.getElementType().isIntOrIndexOrFloat())
1371  return {};
1372 
1373  // reshape(const(x)) -> const(reshape-attr(x))
1374  if (auto operand =
1375  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1376  // Constants must have static shape.
1377  if (!outputTy.hasStaticShape())
1378  return {};
1379 
1380  // Okay to duplicate splat constants.
1381  if (operand.isSplat())
1382  return SplatElementsAttr::get(outputTy,
1383  operand.getSplatValue<Attribute>());
1384 
1385  // Don't duplicate other constants.
1386  if (!getInput1().hasOneUse())
1387  return {};
1388 
1389  llvm::SmallVector<int64_t> shapeVec;
1390  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1391  return {};
1392 
1393  return operand.reshape(
1394  llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1395  }
1396 
1397  return {};
1398 }
1399 
1400 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1401  // If the pad is all zeros we can fold this operation away.
1402  if (adaptor.getPadding() && getInput1().getType() == getType()) {
1403  auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1404  if (densePad && densePad.isSplat() &&
1405  densePad.getSplatValue<APInt>().isZero()) {
1406  return getInput1();
1407  }
1408  }
1409 
1410  return {};
1411 }
1412 
1413 // Fold away cases where a tosa.resize operation returns a copy
1414 // of the input image.
1415 OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1416  auto scaleAttr =
1417  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1418  auto offsetAttr =
1419  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1420  auto borderAttr =
1421  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1422  if (!scaleAttr || !offsetAttr || !borderAttr) {
1423  return {};
1424  }
1425 
1426  auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1427  auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1428  auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1429  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1430  return {};
1431  }
1432 
1433  // Check unit scaling.
1434  if (scale[0] != scale[1] || scale[2] != scale[3]) {
1435  return {};
1436  }
1437 
1438  // There should be no offset.
1439  if (offset[0] != 0 || offset[1] != 0) {
1440  return {};
1441  }
1442 
1443  // There should be no border.
1444  if (border[0] != 0 || border[1] != 0) {
1445  return {};
1446  }
1447 
1448  auto input = getInput();
1449  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1450  auto resultTy = llvm::cast<RankedTensorType>(getType());
1451  if (inputTy != resultTy)
1452  return {};
1453 
1454  return input;
1455 }
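// Illustrative note (not from the original source): the fold above fires for a
// resize whose scale is [n, n, m, m] (numerator equals denominator on both
// axes, i.e. unit scaling), whose offset and border are all zero, and whose
// result type equals its input type; such a resize is a plain copy of the
// input image.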
1456 
1457 OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1458  auto operand = getInput1();
1459  auto operandTy = llvm::cast<ShapedType>(operand.getType());
1460  auto axis = getAxis();
1461  auto operandAttr =
1462  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1463  if (operandAttr)
1464  return operandAttr;
1465 
1466  // If the dim-length is 1, tosa.reverse is a no-op.
1467  if (operandTy.hasRank() &&
1468  (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1469  return operand;
1470 
1471  return {};
1472 }
1473 
1474 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1475  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1476  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1477 
1478  if (!inputTy || !outputTy)
1479  return {};
1480 
1481  if (inputTy == outputTy && inputTy.hasStaticShape())
1482  return getInput1();
1483 
1484  if (!adaptor.getInput1())
1485  return {};
1486 
1487  // Cannot create an ElementsAttr from non-int/float/index types
1488  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1489  !outputTy.getElementType().isIntOrIndexOrFloat())
1490  return {};
1491 
1492  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1493  if (operand.isSplat() && outputTy.hasStaticShape()) {
1494  return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
1495  }
1496 
1497  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1498  outputTy.getNumElements() == 1) {
1499  DenseElementsAttr startElems;
1500  if (!matchPattern(getStart(), m_Constant(&startElems)))
1501  return {};
1502 
1503  llvm::SmallVector<uint64_t> indices =
1504  llvm::to_vector(startElems.getValues<uint64_t>());
1505  auto value = operand.getValues<Attribute>()[indices];
1506  return SplatElementsAttr::get(outputTy, value);
1507  }
1508 
1509  return {};
1510 }
1511 
1512 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1513  if (getOnTrue() == getOnFalse())
1514  return getOnTrue();
1515 
1516  auto predicate =
1517  llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1518  if (!predicate)
1519  return {};
1520 
1521  if (!predicate.isSplat())
1522  return {};
1523  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1524  : getOnFalse();
1525 }
1526 
1527 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1528  if (getInput1().getType() == getType()) {
1529  if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1530  adaptor.getMultiples())) {
1531  if (multiples.isSplat() &&
1532  multiples.getSplatValue<APInt>().getSExtValue() == 1)
1533  return getInput1();
1534  if (auto int_array_attr =
1535  llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1536  if (llvm::all_of(int_array_attr.getValues<APInt>(),
1537  [](APInt v) { return v.getSExtValue() == 1; }))
1538  return getInput1();
1539  }
1540  }
1541  }
1542  return {};
1543 }
1544 
1545 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1546  auto resultTy = llvm::cast<ShapedType>(getType());
1547 
1548  // Transposing splat values just means reshaping.
1549  if (auto input =
1550  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1551  if (input.isSplat() && resultTy.hasStaticShape() &&
1552  input.getType().getElementType() == resultTy.getElementType())
1553  return input.reshape(resultTy);
1554  }
1555 
1556  // Otherwise, fold only if the transpose is the identity permutation.
1557  const llvm::ArrayRef<int32_t> perms = getPerms();
1558 
1559  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1560  return {};
1561 
1562  return getInput1();
1563 }
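// Illustrative example (assumed IR; the printed form of the perms attribute is
// approximated): an identity permutation folds away, so
//   %0 = tosa.transpose %arg0 {perms = array<i32: 0, 1, 2>}
//            : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
// becomes a direct use of %arg0, while a splat constant input with a static
// result type folds to the same splat reshaped to the result type.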
1564 
1565 OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
1566  auto input = getInput1();
1567  // Element-wise log(exp(x)) = x
1568  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
1569  return op.getInput1();
1570  }
1571 
1572  return {};
1573 }
1574 
1575 OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
1576  auto input = getInput1();
1577  // Element-wise exp(log(x)) = x
1578  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
1579  return op.getInput1();
1580  }
1581 
1582  return {};
1583 }
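// Illustrative example (assumed IR) for the two folders above: exp and log
// cancel element-wise, so in
//   %0 = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
//   %1 = tosa.exp %0 : (tensor<4xf32>) -> tensor<4xf32>
// %1 folds to %arg0, and the mirrored log-of-exp chain folds the same way.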
1584 
1585 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1586  // Element-wise negate(negate(x)) = x
1587  // iff all zero points are constant 0
1588  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1589  if (!definingOp) {
1590  // defining op of input1 is not a negate, cannot fold
1591  return {};
1592  }
1593 
1594  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1595  failed(maybeIZp) || *maybeIZp != 0) {
1596  // input1 zero point is not constant 0, cannot fold
1597  return {};
1598  }
1599  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1600  failed(maybeOZp) || *maybeOZp != 0) {
1601  // output zero point is not constant 0, cannot fold
1602  return {};
1603  }
1604  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1605  failed(maybeIZp) || *maybeIZp != 0) {
1606  // definingOp's input1 zero point is not constant 0, cannot fold
1607  return {};
1608  }
1609  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1610  failed(maybeOZp) || *maybeOZp != 0) {
1611  // definingOp's output zero point is not constant 0, cannot fold
1612  return {};
1613  }
1614 
1615  return definingOp.getInput1();
1616 }
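// Illustrative summary (assumed example): two chained tosa.negate ops cancel,
// so negate(negate(%x)) folds to %x, but only when the input and output zero
// points of both ops fold to the constant 0; any non-zero or non-constant
// zero point leaves the pair in place.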
1617 
1618 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1619  auto input = getInput1();
1620  // Element-wise abs(abs(x)) = abs(x)
1621  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1622  return input;
1623  }
1624 
1625  return {};
1626 }
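// Illustrative example (assumed IR): abs is idempotent, so in
//   %0 = tosa.abs %arg0 : (tensor<4xf32>) -> tensor<4xf32>
//   %1 = tosa.abs %0 : (tensor<4xf32>) -> tensor<4xf32>
// %1 folds to %0, the result of the inner abs, not to %arg0.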
1627 
1628 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1629  // Fold consecutive concats on the same axis into a single op.
1630  // Keep track of the operands so we can construct a new concat later.
1631  // Conservatively assume that folding doubles the number of
1632  // operands.
1633  SmallVector<Value, 8> concatOperands;
1634  concatOperands.reserve(2 * getNumOperands());
1635 
1636  // Find all operands that are foldable concats
1637  bool foundFoldableConcat = false;
1638  for (Value operand : getOperands()) {
1639  concatOperands.emplace_back(operand);
1640 
1641  auto producer = operand.getDefiningOp<ConcatOp>();
1642  if (!producer)
1643  continue;
1644 
1645  // Not foldable if axes are not the same
1646  if (getAxis() != producer.getAxis())
1647  continue;
1648 
1649  // Replace the original operand with all incoming operands
1650  foundFoldableConcat = true;
1651  concatOperands.pop_back();
1652  llvm::append_range(concatOperands, producer->getOperands());
1653  }
1654 
1655  if (!foundFoldableConcat)
1656  return {};
1657 
1658  getOperation()->setOperands(concatOperands);
1659  return getResult();
1660 }
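// Illustrative summary (assumed example): a concat that consumes another
// concat along the same axis is flattened in place, so
// concat(concat(%a, %b, axis = 0), %c, axis = 0) becomes
// concat(%a, %b, %c, axis = 0); a nested concat along a different axis is
// left untouched.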
1661 
1662 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1663  auto input = adaptor.getInput1();
1664 
1665  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1666  // Fold splat inputs only.
1667  if (!inputAttr || !inputAttr.isSplat())
1668  return {};
1669 
1670  auto shapeType = llvm::cast<ShapedType>(getType());
1671  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1672  auto floatVal = inputAttr.getSplatValue<APFloat>();
1673  return DenseElementsAttr::get(shapeType,
1674  ReciprocalOp::calcOneElement(floatVal));
1675  }
1676 
1677  return {};
1678 }
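// Illustrative summary (assumed example): the reciprocal of a splat float
// constant is evaluated at fold time via ReciprocalOp::calcOneElement, so a
// splat 2.0 input folds to a splat 0.5 constant of the result type; non-splat
// or non-float inputs are not folded here.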