TosaCanonicalizations.cpp
1 //===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // TOSA canonicalization patterns and folders.
11 //
12 //===----------------------------------------------------------------------===//
13 
21 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <functional>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 //===----------------------------------------------------------------------===//
39 // Operator Canonicalizers.
40 //===----------------------------------------------------------------------===//
41 
42 //===----------------------------------------------------------------------===//
43 // Tensor Data Engine Operators.
44 //===----------------------------------------------------------------------===//
45 
46 // Check that the zero point of the tensor operation and the pad constant are aligned.
47 static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
48  // Check that padConst is a constant value and a scalar tensor
49  DenseElementsAttr padConstAttr;
50  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
51  (padConstAttr.size() != 1)) {
52  return false;
53  }
54 
55  // Check that floating point pad is zero
56  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
57  float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
58  return padConstVal == 0.0f;
59  }
60 
61  // Check that the zp and padConst align for the integer (quantized) case
62  if (auto padConstIntAttr =
63  mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
64  DenseIntElementsAttr zpAttr;
65  // Check that zp is a constant value and a scalar tensor
66  if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
67  return false;
68  }
69 
70  // Check equality
71  int64_t zpVal = (*zpAttr.begin()).getSExtValue();
72  int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
73  return zpVal == padConstVal;
74  }
75 
76  // Bail-out on unsupported type
77  return false;
78 }
79 
80 namespace {
81 template <typename OpTy>
82 struct PoolPadFoldAdaptor;
83 
84 template <>
85 struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
86  using OpTy = tosa::AvgPool2dOp;
87  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
88  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
89  if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
90  newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
91  return false;
92  return true;
93  }
94  static bool checkPadConstCompliance(OpTy op, Value padConst) {
95  return checkMatchingPadConstAndZp(padConst, op.getInputZp());
96  }
97  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
98  Value padInput, ArrayRef<int64_t> newPad) {
99  rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
100  op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
101  op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
102  op.getAccType());
103  }
104 };
105 
106 template <>
107 struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
108  using OpTy = tosa::MaxPool2dOp;
109  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
110  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
111  if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
112  newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
113  return false;
114  return true;
115  }
116  static bool checkPadConstCompliance(OpTy, Value padConst) {
117  // Check that padConst is a constant value and a scalar tensor
118  DenseElementsAttr padConstAttr;
119  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
120  padConstAttr.size() != 1) {
121  return false;
122  }
123 
124  // The pad constant must be the minimum representable value for the merge to be valid
125  if (auto padConstFpAttr =
126  mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
127  const APFloat padConstVal = *padConstFpAttr.begin();
128  const APFloat lowestVal =
129  APFloat::getLargest(padConstVal.getSemantics(), true);
130  return padConstVal == lowestVal;
131  } else if (auto padConstIntAttr =
132  mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
133  const APInt padConstVal = *padConstIntAttr.begin();
134  const unsigned int bitWidth = padConstVal.getBitWidth();
135  const APInt lowestVal =
136  padConstIntAttr.getElementType().isUnsignedInteger()
137  ? APInt::getZero(bitWidth)
138  : APInt::getSignedMinValue(bitWidth);
139  return padConstVal == lowestVal;
140  }
141 
142  // Bail-out on unsupported type
143  return false;
144  }
145  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
146  Value padInput, ArrayRef<int64_t> newPad) {
147  rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
148  op, op.getType(), padInput, op.getKernel(), op.getStride(),
149  rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
150  }
151 };
152 
153 template <typename OpTy>
154 struct ConvPadFoldAdaptor {
155  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
156  return true;
157  }
158  static bool checkPadConstCompliance(OpTy op, Value padConst) {
159  return checkMatchingPadConstAndZp(padConst, op.getInputZp());
160  }
161  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
162  Value padInput, ArrayRef<int64_t> newPad) {
163  rewriter.replaceOpWithNewOp<OpTy>(
164  op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
165  op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
166  op.getDilationAttr(), op.getAccType(), op.getLocalBound());
167  }
168 };
169 
170 // This pattern attempts to fold a `tosa.pad` operator into a following tensor
171 // operation like `tosa.conv2d` by merging the padding of the pad operator
172 // directly into the implicit padding of the tensor operation.
173 // This helps eliminate the explicit pad operator if it becomes unused.
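//
// Illustrative example (shapes, attribute values, and value names below are
// hypothetical, not taken from this file): an explicit zero pad of 1 on the
// H and W dimensions feeding a conv2d with pad = [0, 0, 0, 0] can be
// absorbed into the conv2d as pad = [1, 1, 1, 1], after which the tosa.pad
// becomes dead:
//
//   %p = tosa.pad %input, %padding, %pad_const : ... -> tensor<1x10x10x8xf32>
//   %c = tosa.conv2d %p, %w, %b, ... {pad = array<i64: 0, 0, 0, 0>, ...}
// is rewritten to
//   %c = tosa.conv2d %input, %w, %b, ... {pad = array<i64: 1, 1, 1, 1>, ...}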
174 template <typename OpTy, typename AdaptorTy>
175 struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
176  using OpRewritePattern<OpTy>::OpRewritePattern;
177 
178  LogicalResult matchAndRewrite(OpTy tensorOp,
179  PatternRewriter &rewriter) const override {
180  // Check producer is a tosa::PadOp
181  auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
182  if (!padOp)
183  return rewriter.notifyMatchFailure(tensorOp,
184  "Producer must be a tosa::PadOp.");
185 
186  // Validate that tensor operation has sane padding
187  const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
188  if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
189  return rewriter.notifyMatchFailure(
190  tensorOp, "Tensor operation padding shall have 4 elements.");
191 
192  // Validate tosa::PadOp padding
193  DenseIntElementsAttr padOpPadding;
194  if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
195  return rewriter.notifyMatchFailure(
196  tensorOp,
197  "The `padding` input specified on the tosa::PadOp must be constant.");
198  }
199  // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
200  // C_after
201  if (padOpPadding.size() != 8)
202  return rewriter.notifyMatchFailure(tensorOp,
203  "Pad padding should have 8 elements.");
204  int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
205  int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
206  int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
207  int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
208  int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
209  int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
210  int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
211  int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
212 
213  if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
214  return rewriter.notifyMatchFailure(
215  tensorOp, "Folding padding in N or C dimensions is not supported.");
216 
217  // Fold padding from Pad into the tensor operation
218  // 4 elements - pad_top, pad_bottom, pad_left, pad_right
219  SmallVector<int64_t> foldedPad(tensorOpPad.size());
220  foldedPad[0] = padHBefore + tensorOpPad[0];
221  foldedPad[1] = padHAfter + tensorOpPad[1];
222  foldedPad[2] = padWBefore + tensorOpPad[2];
223  foldedPad[3] = padWAfter + tensorOpPad[3];
224 
225  // Check kernel related restrictions
226  if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
227  return rewriter.notifyMatchFailure(
228  tensorOp, "Padding size not aligned with kernel restrictions.");
229  }
230 
231  // Check padding constant restrictions
232  if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
233  return rewriter.notifyMatchFailure(
234  tensorOp,
235  "Padding constant is not aligned with operator zero-point.");
236  }
237 
238  // Check that the folded padding does not exceed the 8K level limit (8192) for now
239  if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
240  return rewriter.notifyMatchFailure(
241  tensorOp, "Padding size exceeds the 8K level limit.");
242  }
243 
244  // Create operator
245  AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
246  foldedPad);
247 
248  return success();
249  }
250 };
251 } // namespace
252 
253 void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
254  MLIRContext *context) {
255  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
256  PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
257  context);
258 }
259 
260 void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
261  MLIRContext *context) {
262  results.add<
263  FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
264  context);
265 }
266 
267 void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
268  MLIRContext *context) {
269  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
270  ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
271  context);
272 }
273 
274 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
275  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
276 
277  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
278  PatternRewriter &rewriter) const override {
279  Value input = op.getInput();
280  Value output = op.getOutput();
281  ShapedType inputType = llvm::cast<ShapedType>(input.getType());
282  ShapedType outputType = llvm::cast<ShapedType>(output.getType());
283 
284  if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
285  return failure();
286  }
287 
288  // If both the input and output spatial (H, W) dimensions are 1, this pooling is a no-op.
289  ArrayRef<int64_t> outputShape = outputType.getShape();
290  if (outputShape[1] != 1 || outputShape[2] != 1) {
291  return failure();
292  }
293 
294  ArrayRef<int64_t> inputShape = inputType.getShape();
295  if (inputShape[1] != 1 || inputShape[2] != 1) {
296  return failure();
297  }
298 
299  rewriter.replaceOp(op, input);
300  return success();
301  }
302 };
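// Illustrative example (hypothetical shapes): a max_pool2d whose input and
// output are both 1x1 in H and W reduces over a single element per channel,
// e.g. (tensor<1x1x1x8xf32>) -> tensor<1x1x1x8xf32>, so the op is replaced
// by its input.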
303 
304 void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
305  MLIRContext *context) {
306  results.add<MaxPool2dIsNoOp,
307  FoldPadToTensorOp<tosa::MaxPool2dOp,
308  PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
309  context);
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // Data Layout / Memory Reinterpretation.
314 //===----------------------------------------------------------------------===//
315 
316 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
317  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
318 
319  LogicalResult matchAndRewrite(tosa::ConcatOp op,
320  PatternRewriter &rewriter) const override {
321  if (op.getInput1().size() != 1)
322  return failure();
323  if (op.getInput1().front().getType() != op.getType()) {
324  rewriter
325  .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
326  op.getInput1().front())
327  .getResult();
328  return success();
329  }
330 
331  rewriter.replaceOp(op, op.getInput1().front());
332  return success();
333  }
334 };
335 
336 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
337  MLIRContext *context) {
338  results.add<ConcatOptimization>(context);
339 }
340 
341 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
342  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
343  if (!notOp)
344  return failure();
345  rewriter.modifyOpInPlace(op, [&]() {
346  op.getOperation()->setOperands(
347  {notOp.getInput1(), op.getInput3(), op.getInput2()});
348  });
349  return success();
350 }
351 
352 struct ConsolidateTransposeOptimization
353     : public OpRewritePattern<tosa::TransposeOp> {
354  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
355 
356  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
357  PatternRewriter &rewriter) const override {
358  // Input is also TransposeOp - transpose(transpose(A)).
359  auto innerTranspose =
360  transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
361  if (!innerTranspose)
362  return rewriter.notifyMatchFailure(transposeOp,
363  "input must be transpose operation");
364 
365  const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
366  const llvm::ArrayRef<int32_t> innerTransposePerms =
367  innerTranspose.getPerms();
368 
369  if (transposePerms.size() != innerTransposePerms.size())
370  return rewriter.notifyMatchFailure(
371  transposeOp,
372  "transpose and inner transpose perms sizes must be equal");
373  if (transposePerms.empty())
374  return rewriter.notifyMatchFailure(
375  transposeOp, "transpose perms must not be empty");
376 
377  // Consolidate transposes into one transpose.
378  SmallVector<int32_t> perms(transposePerms.size());
379  for (int i = 0, s = transposePerms.size(); i < s; ++i)
380  perms[i] = innerTransposePerms[transposePerms[i]];
381 
382  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
383  transposeOp, transposeOp.getResult().getType(),
384  innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
385 
386  return success();
387  }
388 };
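// Illustrative example (hypothetical perms): the combined permutation is
// newPerms[i] = innerPerms[outerPerms[i]]. With innerPerms = [1, 2, 0] and
// outerPerms = [2, 0, 1] this gives [0, 1, 2], i.e. the two transposes
// cancel into an identity transpose, which TransposeOp::fold then removes.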
389 
390 // Detects when a tosa.transpose is equivalent to a tosa.reshape operation.
391 struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
392  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
393 
394  LogicalResult matchAndRewrite(tosa::TransposeOp op,
395  PatternRewriter &rewriter) const override {
396  if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
397  return rewriter.notifyMatchFailure(
398  op, "Src is from transpose, can compose transposes");
399 
400  Value result = op.getResult();
401  for (Operation *subop : result.getUsers()) {
402  if (isa_and_nonnull<tosa::TransposeOp>(subop))
403  return rewriter.notifyMatchFailure(
404  op, "Dest is used by transpose, can compose transposes");
405  }
406 
407  auto input = op.getInput1();
408  auto inputTy = llvm::cast<ShapedType>(input.getType());
409  if (!inputTy.hasRank())
410  return rewriter.notifyMatchFailure(op, "Unranked input.");
411 
412  int64_t numDynDims = 0;
413  for (int i = 0; i < inputTy.getRank(); ++i)
414  if (inputTy.isDynamicDim(i))
415  numDynDims++;
416 
417  if (numDynDims > 1)
418  return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
419 
420  const llvm::ArrayRef<int32_t> permValues = op.getPerms();
421 
422  SmallVector<int64_t> nonZeroPerms;
423  nonZeroPerms.reserve(permValues.size());
424  for (auto idx : permValues) {
425  auto sz = inputTy.getDimSize(idx);
426  if (sz != 1)
427  nonZeroPerms.push_back(idx);
428  }
429 
430  for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
431  if (nonZeroPerms[i - 1] > nonZeroPerms[i])
432  return rewriter.notifyMatchFailure(op,
433  "Transpose changes memory layout.");
434 
435  SmallVector<int64_t> newShape;
436  newShape.reserve(inputTy.getRank());
437  for (int i = 0, s = inputTy.getRank(); i < s; ++i)
438  newShape.push_back(inputTy.getDimSize(permValues[i]));
439 
440  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
441  op, op.getType(), op.getInput1(),
442  getTosaConstShape(rewriter, op.getLoc(), newShape));
443  return success();
444  }
445 };
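// Illustrative example (hypothetical shape): transposing tensor<1x4x1x8xf32>
// with perms = [1, 0, 3, 2] only moves unit dimensions; the non-unit dims
// (4 and 8) keep their relative order, so the transpose is rewritten as a
// tosa.reshape to tensor<4x1x8x1xf32>.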
446 
447 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
448  MLIRContext *context) {
449  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
450 }
451 
452 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
453  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
454 
455  LogicalResult matchAndRewrite(tosa::ClampOp op,
456  PatternRewriter &rewriter) const override {
457  Value input = op.getInput();
458  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
459  auto inputElementType = inputType.getElementType();
460 
461  if (!inputType.hasStaticShape()) {
462  return failure();
463  }
464 
465  if (isa<FloatType>(inputElementType)) {
466  // Unlike integer types, floating point types can represent infinity.
467  auto minClamp =
468  llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
469  auto maxClamp =
470  llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
471  bool isMin = minClamp.isNegInfinity();
472  bool isMax = maxClamp.isInfinity();
473 
474  if (isMin && isMax) {
475  rewriter.replaceOp(op, input);
476  return success();
477  }
478  return failure();
479  }
480 
481  if (inputElementType.isUnsignedInteger()) {
482  int64_t minClamp =
483  llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
484  int64_t maxClamp =
485  llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
486 
487  int64_t intMin =
488  APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
489  .getZExtValue();
490  int64_t intMax =
491  APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
492  .getZExtValue();
493 
494  if (minClamp <= intMin && maxClamp >= intMax) {
495  rewriter.replaceOp(op, input);
496  return success();
497  }
498  return failure();
499  }
500 
501  if (llvm::isa<IntegerType>(inputElementType)) {
502  int64_t minClamp =
503  llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
504  int64_t maxClamp =
505  llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
506 
507  int64_t intMin =
508  APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
509  .getSExtValue();
510  int64_t intMax =
511  APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
512  .getSExtValue();
513 
514  if (minClamp <= intMin && maxClamp >= intMax) {
515  rewriter.replaceOp(op, input);
516  return success();
517  }
518  return failure();
519  }
520 
521  return failure();
522  }
523 };
524 
525 // Attempts the following transformation:
526 //
527 // For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
528 // tensor X the following identity holds:
529 //
530 // CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
531 //
532 // subject to the following valid NaN propagation semantics:
533 // ---------------------------------------------
534 // | OUTER CLAMP | INNER CLAMP | RESULT MODE |
535 // |-------------|-------------|-------------|
536 // | PROPAGATE   | PROPAGATE   | PROPAGATE   |
537 // | PROPAGATE   | IGNORE      | IGNORE      |
538 // | IGNORE      | PROPAGATE   | INVALID     |
539 // | IGNORE      | IGNORE      | IGNORE      |
540 // ---------------------------------------------
541 
542 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
543  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
544 
545  // Helper structure to describe the range of a clamp operation.
546  template <typename T>
547  struct ClampRange {
548  ClampRange(const T &start, const T &end) : start(start), end(end) {}
549  T start;
550  T end;
551 
552  // Helper function to determine if two Clamp ranges intersect.
553  bool intersects(const ClampRange<T> &otherRange) {
554  return start < otherRange.end && otherRange.start < end;
555  }
556  };
557 
558  LogicalResult matchAndRewrite(tosa::ClampOp op,
559  PatternRewriter &rewriter) const override {
560  Value input = op.getInput();
561 
562  // Check the input to the CLAMP op is itself a CLAMP.
563  auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
564  if (!clampOp)
565  return failure();
566 
567  // Check we have a valid NaN propagation combination.
568  const auto opNanMode = op.getNanMode();
569  const auto clampNanMode = clampOp.getNanMode();
570  if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
571  return failure();
572 
573  auto maxValAttr = op.getMaxValAttr();
574  auto minValAttr = op.getMinValAttr();
575  auto clampOpMaxValAttr = clampOp.getMaxValAttr();
576  auto clampOpMinValAttr = clampOp.getMinValAttr();
577 
578  auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
579  if (auto quantType =
580  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
581  inputEType = quantType.getStorageType();
582  }
583 
584  Attribute newMinValAttr, newMaxValAttr;
585  if (mlir::isa<FloatType>(inputEType)) {
586  auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
587  auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
588  auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
589  auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
590 
591  // Check we have intersecting ranges.
592  const auto opMinFloat = floatMinValAttr.getValue();
593  const auto opMaxFloat = floatMaxValAttr.getValue();
594  const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
595  const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
596  ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
597  ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
598  clampOpMaxFloat);
599  if (!opRangeFloatRange.intersects(clampRangeFloatRange))
600  return failure();
601 
602  // Run the transformation.
603  auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
604  auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
605  newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
606  newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
607  } else {
608  assert(mlir::isa<IntegerType>(inputEType));
609  auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
610  auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
611  auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
612  auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
613 
614  if (inputEType.isUnsignedInteger()) {
615  // Check we have intersecting ranges.
616  const auto opMinInt = intMinValAttr.getUInt();
617  const auto opMaxInt = intMaxValAttr.getUInt();
618  const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
619  const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
620  ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
621  ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
622  clampOpMaxInt);
623  if (!opRangeIntRange.intersects(clampRangeIntRange))
624  return failure();
625 
626  // Run the transformation.
627  auto newMinVal = std::max(opMinInt, clampOpMinInt);
628  auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
629  newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
630  newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
631  } else {
632  // Check we have intersecting ranges.
633  const auto opMinInt = intMinValAttr.getInt();
634  const auto opMaxInt = intMaxValAttr.getInt();
635  const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
636  const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
637  ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
638  ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
639  clampOpMaxInt);
640  if (!opRangeIntRange.intersects(clampRangeIntRange))
641  return failure();
642 
643  // Run the transformation.
644  auto newMinVal = std::max(opMinInt, clampOpMinInt);
645  auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
646  newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
647  newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
648  }
649  }
650 
651  rewriter.replaceOpWithNewOp<tosa::ClampOp>(
652  op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
653  rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
654  : opNanMode));
655  return success();
656  }
657 };
658 
659 void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
660  MLIRContext *context) {
661  results.add<ClampIsNoOp>(context);
662  results.add<ClampClampOptimization>(context);
663 }
664 
665 struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
666  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
667 
668  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
669  PatternRewriter &rewriter) const override {
670  Value sliceInput = sliceOp.getInput1();
671  auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
672  if (!concatOp)
673  return rewriter.notifyMatchFailure(
674  sliceOp, "slice input must be concat operation");
675 
676  OperandRange inputs = concatOp.getInput1();
677  auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
678  if (!concatType || !concatType.hasStaticShape())
679  return rewriter.notifyMatchFailure(
680  sliceOp, "slice input must be a static ranked tensor");
681  int32_t axis = concatOp.getAxis();
682 
683  DenseElementsAttr startElems;
684  DenseElementsAttr sizeElems;
685 
686  if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
687  return rewriter.notifyMatchFailure(
688  sliceOp, "start of slice must be a static ranked shape");
689 
690  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
691  return rewriter.notifyMatchFailure(
692  sliceOp, "size of slice must be a static ranked shape");
693 
694  llvm::SmallVector<int64_t> sliceStarts =
695  llvm::to_vector(startElems.getValues<int64_t>());
696  llvm::SmallVector<int64_t> sliceSizes =
697  llvm::to_vector(sizeElems.getValues<int64_t>());
698 
699  // Validate slice on the concatenated axis. Slicing along this
700  // axis should span only one of the inputs to the concatenate
701  // operation.
702  std::optional<Value> replaceWithSlice;
703  for (auto input : inputs) {
704  auto inputType = dyn_cast<RankedTensorType>(input.getType());
705  if (!inputType || !inputType.hasStaticShape())
706  return rewriter.notifyMatchFailure(
707  sliceOp, "concat input must be a static ranked tensor");
708 
709  if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
710  inputType.getDimSize(axis)) {
711  auto start_op =
712  getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
713  auto size_op =
714  getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
715  replaceWithSlice =
716  rewriter
717  .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
718  input, start_op, size_op)
719  .getResult();
720  break;
721  }
722  sliceStarts[axis] -= inputType.getDimSize(axis);
723  }
724 
725  if (!replaceWithSlice)
726  return rewriter.notifyMatchFailure(
727  sliceOp, "corresponding concat input not found for slice");
728 
729  rewriter.replaceOp(sliceOp, replaceWithSlice.value());
730  return success();
731  }
732 };
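// Illustrative example (hypothetical shapes): for
//   %c = tosa.concat %a, %b {axis = 1}  // %a: tensor<2x3xf32>, %b: tensor<2x5xf32>
// a slice of %c with start = [0, 4] and size = [2, 2] lies entirely inside
// %b, so it is rewritten as a slice of %b with start = [0, 1], size = [2, 2],
// and this use of the concat goes away.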
733 
734 struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
735  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
736 
737  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
738  PatternRewriter &rewriter) const override {
739  Value sliceInput = sliceOp.getInput1();
740 
741  // Check if producer is a PadOp
742  auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
743  if (!padOp)
744  return rewriter.notifyMatchFailure(sliceOp,
745  "slice input must be a pad operation");
746 
747  // Check PadOp has a single consumer
748  if (!padOp->hasOneUse())
749  return rewriter.notifyMatchFailure(sliceOp,
750  "pad shall have a single consumer");
751 
752  // Check input is statically ranked
753  auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
754  auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
755  if (!inputTy || !padTy || !inputTy.hasRank())
756  return rewriter.notifyMatchFailure(sliceOp,
757  "slice input must be a ranked tensor");
758 
759  // Validate and extract tosa::PadOp padding
760  DenseIntElementsAttr paddingElems;
761  if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
762  return rewriter.notifyMatchFailure(
763  sliceOp,
764  "`padding` input specified on the tosa::PadOp must be constant.");
765  }
766  llvm::SmallVector<int64_t> padPaddings =
767  llvm::to_vector(paddingElems.getValues<int64_t>());
768 
769  // Extract slice parameters
770  DenseElementsAttr startElems;
771  if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
772  return rewriter.notifyMatchFailure(
773  sliceOp, "start of slice must be a static ranked shape");
774  llvm::SmallVector<int64_t> sliceStarts =
775  llvm::to_vector(startElems.getValues<int64_t>());
776 
777  DenseElementsAttr sizeElems;
778  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
779  return rewriter.notifyMatchFailure(
780  sliceOp, "size of slice must be a static ranked shape");
781  llvm::SmallVector<int64_t> sliceSizes =
782  llvm::to_vector(sizeElems.getValues<int64_t>());
783 
784  // Check if dynamic dimensions are sliced
785  const int64_t rank = inputTy.getRank();
786  if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
787  const bool isDimDynamic = inputTy.isDynamicDim(i);
788  const bool isDimSliced =
789  (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
790 
791  return isDimDynamic && isDimSliced;
792  })) {
793  return rewriter.notifyMatchFailure(
794  sliceOp, "axes that are sliced shall be statically known.");
795  }
796 
797  // Update the parameters
798  llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
799  llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
800  llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
801  bool updated = false;
802 
803  for (int64_t i = 0; i < rank; ++i) {
804  const int64_t padLo = padPaddings[i * 2];
805  const int64_t padHi = padPaddings[i * 2 + 1];
806  const int64_t sliceStart = sliceStarts[i];
807  const int64_t sliceSize = sliceSizes[i];
808  const int64_t sliceEnd = sliceStart + sliceSize;
809 
810  // If the dimension is dynamic, pass the padding and slice start through unchanged
811  if (inputTy.isDynamicDim(i)) {
812  newPadPaddings[i * 2] = padLo;
813  newPadPaddings[i * 2 + 1] = padHi;
814  newSliceStarts[i] = sliceStart;
815  continue;
816  }
817 
818  // Handle static dimensions
819  const int64_t dimSize = inputTy.getShape()[i];
820  const int64_t dimTotal = padLo + dimSize + padHi;
821 
822  // Check slice within bounds
823  if (sliceStart < 0 || sliceEnd > dimTotal)
824  return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
825 
826  // Compute updated slice start parameter
827  const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
828  newSliceStarts[i] = newSliceStart;
829  updated |= newSliceStart != sliceStart;
830 
831  // Compute updated pad parameters
832  const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
833  const int64_t newPadHi =
834  std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
835  newPadPaddings[i * 2] = newPadLo;
836  newPadPaddings[i * 2 + 1] = newPadHi;
837  updated |= (newPadLo != padLo) || (newPadHi != padHi);
838 
839  // Calculate new pad output shape
840  newPadShape[i] =
841  newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
842  }
843 
844  // Check that we actually need to proceed with the rewrite
845  if (!updated)
846  return rewriter.notifyMatchFailure(
847  sliceOp, "terminate condition; nothing to rewrite");
848 
849  // Create a PadOp with updated padding
850  auto newPaddingsOp =
851  getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
852  auto newPadTy =
853  RankedTensorType::get(newPadShape, inputTy.getElementType());
854  auto newPadOp = rewriter.create<tosa::PadOp>(
855  padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
856  padOp.getPadConst());
857 
858  // Update SliceOp and point to new PadOp
859  auto newStartOp =
860  getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
861  rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
862  newPadOp.getResult(), newStartOp,
863  sliceOp.getSize());
864 
865  return success();
866  }
867 };
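// Illustrative example (hypothetical sizes, one dimension shown): padding a
// dimension of size 8 by [2, 2] gives 12, and slicing it with start = 1,
// size = 10 keeps positions 1..10 of the padded dimension. The rewrite
// shrinks the pad to [1, 1] (new padded size 10) and moves the slice start
// to 0, so less padded data is materialized while the slice result stays
// the same.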
868 
869 // Update size operand of tosa.slice if size has dynamic dims but corresponding
870 // output dim is static
871 struct SliceDynamicSizeCanonicalization
872     : public OpRewritePattern<tosa::SliceOp> {
873  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
874 
875  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
876  PatternRewriter &rewriter) const override {
877  ShapedType resultType = cast<ShapedType>(sliceOp.getType());
878 
879  ElementsAttr sizeElems;
880  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
881  return rewriter.notifyMatchFailure(
882  sliceOp, "size of slice must be a static ranked shape");
883  }
884 
885  llvm::SmallVector<int64_t> sliceSizes =
886  llvm::to_vector(sizeElems.getValues<int64_t>());
887 
888  bool replaceSliceSize{false};
889  // If the size operand has a -1 (dynamic) entry but the corresponding output
890  // dimension is statically known, update the size entry to match the known
891  // output dimension.
892  for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
893  if (size == -1 && !resultType.isDynamicDim(index)) {
894  sliceSizes[index] = resultType.getDimSize(index);
895  replaceSliceSize = true;
896  }
897  }
898 
899  if (!replaceSliceSize) {
900  return rewriter.notifyMatchFailure(
901  sliceOp, "no dynamic dimension in the slice size resolves to a "
902  "statically known output dimension");
903  }
904 
905  auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
906  auto newSliceOp = rewriter.create<tosa::SliceOp>(
907  sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
908  sliceOp.getStart(), size_op);
909 
910  rewriter.replaceOp(sliceOp, newSliceOp.getResult());
911  return success();
912  }
913 };
914 
915 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
916  MLIRContext *context) {
917  results.add<ConcatSliceOptimization, PadSliceOptimization,
918  SliceDynamicSizeCanonicalization>(context);
919 }
920 
921 //===----------------------------------------------------------------------===//
922 // Operator Folders.
923 //===----------------------------------------------------------------------===//
924 
925 template <typename IntFolder, typename FloatFolder>
926 DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
927  RankedTensorType returnTy) {
928  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
929  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
930  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
931  if (lETy != rETy)
932  return {};
933 
934  if (llvm::isa<IntegerType>(lETy)) {
935  APInt l = lhs.getSplatValue<APInt>();
936  APInt r = rhs.getSplatValue<APInt>();
937  auto result = IntFolder()(l, r);
938  return DenseElementsAttr::get(returnTy, result);
939  }
940 
941  if (llvm::isa<FloatType>(lETy)) {
942  APFloat l = lhs.getSplatValue<APFloat>();
943  APFloat r = rhs.getSplatValue<APFloat>();
944  auto result = FloatFolder()(l, r);
945  return DenseElementsAttr::get(returnTy, result);
946  }
947  }
948 
949  return {};
950 }
951 
952 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
953  if (llvm::isa<FloatType>(elemType))
954  return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
955  if (llvm::isa<IntegerType>(elemType))
956  return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
957  return false;
958 }
959 
960 static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
961  if (llvm::isa<FloatType>(elemType))
962  return val && val.isSplat() &&
963  val.getSplatValue<APFloat>().isExactlyValue(1.0);
964  if (llvm::isa<IntegerType>(elemType)) {
965  const int64_t shifted = 1LL << shift;
966  return val && val.isSplat() &&
967  val.getSplatValue<APInt>().getSExtValue() == shifted;
968  }
969  return false;
970 }
971 
972 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
973  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
974  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
975  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
976  if (!lhsTy || !rhsTy || !resultTy)
977  return {};
978 
979  // Cannot create an ElementsAttr from non-int/float/index types
980  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
981  !rhsTy.getElementType().isIntOrIndexOrFloat())
982  return {};
983 
984  auto resultETy = resultTy.getElementType();
985  auto lhsAttr =
986  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
987  auto rhsAttr =
988  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
989 
990  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
991  return getInput1();
992  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
993  return getInput2();
994 
995  if (!lhsAttr || !rhsAttr)
996  return {};
997 
998  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
999  resultTy);
1000 }
1001 
1002 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1003  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1004  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1005  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1006  !outputTy.hasStaticShape())
1007  return {};
1008 
1009  if (inputTy.getDimSize(getAxis()) == 1)
1010  return DenseElementsAttr::get(outputTy, 0);
1011 
1012  return {};
1013 }
1014 
1015 OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1016  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1017  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1018  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1019  if (!lhsTy || !rhsTy || !resultTy)
1020  return {};
1021  if (lhsTy != rhsTy)
1022  return {};
1023 
1024  // IntDivOp inputs must be integer type, no need to check for quantized type
1025  auto resultETy = resultTy.getElementType();
1026  auto lhsAttr =
1027  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1028  auto rhsAttr =
1029  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1030  if (lhsAttr && lhsAttr.isSplat()) {
1031  if (llvm::isa<IntegerType>(resultETy) &&
1032  lhsAttr.getSplatValue<APInt>().isZero())
1033  return lhsAttr;
1034  }
1035 
1036  if (rhsAttr && rhsAttr.isSplat()) {
1037  if (llvm::isa<IntegerType>(resultETy) &&
1038  rhsAttr.getSplatValue<APInt>().isOne())
1039  return getInput1();
1040  }
1041 
1042  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1043  llvm::isa<IntegerType>(resultETy)) {
1044  APInt l = lhsAttr.getSplatValue<APInt>();
1045  APInt r = rhsAttr.getSplatValue<APInt>();
1046  if (!r.isZero()) {
1047  APInt result = l.sdiv(r);
1048  return DenseElementsAttr::get(resultTy, result);
1049  }
1050  }
1051 
1052  return {};
1053 }
1054 
1055 namespace {
1056 // Calculate (lhs * rhs) >> shift according to the TOSA spec.
1057 // Return std::nullopt if the result is out of int32_t range when shift > 0.
1058 std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1059  unsigned bitwidth) {
1060  APInt result = lhs.sext(64) * rhs.sext(64);
1061 
1062  if (shift > 0) {
1063  auto round = APInt(64, 1) << (shift - 1);
1064  result += round;
1065  result.ashrInPlace(shift);
1066  // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
1067  if (!(result.getSExtValue() >= INT32_MIN &&
1068  result.getSExtValue() <= INT32_MAX)) {
1069  // REQUIRE failed
1070  return std::nullopt;
1071  }
1072  }
1073 
1074  return result.trunc(bitwidth);
1075 }
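// Illustrative worked example (hypothetical values): for lhs = 5, rhs = 7,
// shift = 2 the 64-bit product is 35, the rounding term is 1 << (shift - 1)
// = 2, and (35 + 2) >> 2 = 9, which passes the int32 range check and is
// truncated back to the operand bitwidth.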
1076 
1077 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1078  RankedTensorType ty, int32_t shift) {
1079  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1080  if (llvm::isa<IntegerType>(ty.getElementType())) {
1081  APInt l = lhs.getSplatValue<APInt>();
1082  APInt r = rhs.getSplatValue<APInt>();
1083 
1084  if (shift == 0) {
1085  return DenseElementsAttr::get(ty, l * r);
1086  }
1087 
1088  auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1089  const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1090  if (!result)
1091  return {};
1092  return DenseElementsAttr::get(ty, result.value());
1093  }
1094 
1095  if (llvm::isa<FloatType>(ty.getElementType())) {
1096  APFloat l = lhs.getSplatValue<APFloat>();
1097  APFloat r = rhs.getSplatValue<APFloat>();
1098  APFloat result = l * r;
1099  return DenseElementsAttr::get(ty, result);
1100  }
1101  }
1102 
1103  return {};
1104 }
1105 } // namespace
1106 
1107 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1108  auto lhs = getInput1();
1109  auto rhs = getInput2();
1110  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1111  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1112  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1113  if (!lhsTy || !rhsTy || !resultTy)
1114  return {};
1115 
1116  auto resultETy = resultTy.getElementType();
1117  auto lhsAttr =
1118  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1119  auto rhsAttr =
1120  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1121 
1122  // The result right-shift applies only to the i32 data type. For simplicity,
1123  // synthesize a zero shift for other data types.
1124  int32_t shift = 0;
1125  if (resultETy.isInteger(32)) {
1126  ElementsAttr shift_elem;
1127  if (getShift().getImpl()) {
1128  if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1129  // cannot be folded when the shift value is unknown.
1130  return {};
1131  shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1132  }
1133  }
1134 
1135  if (rhsTy == resultTy) {
1136  if (isSplatZero(resultETy, lhsAttr))
1137  return lhsAttr.resizeSplat(resultTy);
1138  if (isSplatOne(resultETy, lhsAttr, shift))
1139  return rhs;
1140  }
1141  if (lhsTy == resultTy) {
1142  if (isSplatZero(resultETy, rhsAttr))
1143  return rhsAttr.resizeSplat(resultTy);
1144  if (isSplatOne(resultETy, rhsAttr, shift))
1145  return lhs;
1146  }
1147 
1148  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1149 }
1150 
1151 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1152  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1153  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1154  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1155  if (!lhsTy || !rhsTy || !resultTy)
1156  return {};
1157 
1158  // Cannot create an ElementsAttr from non-int/float/index types
1159  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1160  !rhsTy.getElementType().isIntOrIndexOrFloat())
1161  return {};
1162 
1163  auto resultETy = resultTy.getElementType();
1164  auto lhsAttr =
1165  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1166  auto rhsAttr =
1167  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1168 
1169  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1170  return getInput1();
1171 
1172  if (!lhsAttr || !rhsAttr)
1173  return {};
1174 
1175  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
1176  resultTy);
1177 }
1178 
1179 namespace {
1180 template <typename Cmp>
1181 struct ComparisonFold {
1182  ComparisonFold() = default;
1183  APInt operator()(const APInt &l, const APInt &r) {
1184  return APInt(1, Cmp()(l, r));
1185  }
1186 
1187  APInt operator()(const APFloat &l, const APFloat &r) {
1188  return APInt(1, Cmp()(l, r));
1189  }
1190 };
1191 
1192 struct APIntFoldGreater {
1193  APIntFoldGreater() = default;
1194  APInt operator()(const APInt &l, const APInt &r) {
1195  return APInt(1, l.sgt(r));
1196  }
1197 };
1198 
1199 struct APIntFoldGreaterEqual {
1200  APIntFoldGreaterEqual() = default;
1201  APInt operator()(const APInt &l, const APInt &r) {
1202  return APInt(1, l.sge(r));
1203  }
1204 };
1205 } // namespace
1206 
1207 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1208  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1209  auto lhsAttr =
1210  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1211  auto rhsAttr =
1212  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1213 
1214  if (!lhsAttr || !rhsAttr)
1215  return {};
1216 
1217  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1218  lhsAttr, rhsAttr, resultTy);
1219 }
1220 
1221 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1222  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1223  auto lhsAttr =
1224  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1225  auto rhsAttr =
1226  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1227 
1228  if (!lhsAttr || !rhsAttr)
1229  return {};
1230 
1231  return binaryFolder<APIntFoldGreaterEqual,
1232  ComparisonFold<std::greater_equal<APFloat>>>(
1233  lhsAttr, rhsAttr, resultTy);
1234 }
1235 
1236 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1237  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1238  auto lhsAttr =
1239  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1240  auto rhsAttr =
1241  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1242  Value lhs = getInput1();
1243  Value rhs = getInput2();
1244  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1245 
1246  // Comparing an integer value to itself is always true. We cannot do this
1247  // for floats because NaN compares unequal to itself.
1248  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1249  resultTy.hasStaticShape() && lhs == rhs) {
1250  return DenseElementsAttr::get(resultTy, true);
1251  }
1252 
1253  if (!lhsAttr || !rhsAttr)
1254  return {};
1255 
1256  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1257  ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1258  resultTy);
1259 }
1260 
1261 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1262  if (getInput().getType() == getType())
1263  return getInput();
1264 
1265  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1266  if (!operand)
1267  return {};
1268 
1269  auto inTy = llvm::cast<ShapedType>(getInput().getType());
1270  auto outTy = llvm::cast<ShapedType>(getType());
1271  auto inETy = inTy.getElementType();
1272  auto outETy = outTy.getElementType();
1273 
1274  if (operand.isSplat()) {
1275  if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1276  bool overflow;
1277  auto splatVal = operand.getSplatValue<APFloat>();
1278  auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1279  splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1280  &overflow);
1281  return SplatElementsAttr::get(outTy, splatVal);
1282  }
1283 
1284  if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1285  auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1286  APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1287  splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1288  llvm::RoundingMode::NearestTiesToEven);
1289  return SplatElementsAttr::get(outTy, splatVal);
1290  }
1291 
1292  if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1293  auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1294  auto intVal = APSInt(
1295  llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1296  auto floatVal = operand.getSplatValue<APFloat>();
1297  bool exact;
1298  floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1299  &exact);
1300  return SplatElementsAttr::get(outTy, intVal);
1301  }
1302 
1303  if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1304  auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1305  bool trunc =
1306  inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1307  auto intVal = operand.getSplatValue<APInt>();
1308  auto bitwidth = outETy.getIntOrFloatBitWidth();
1309 
1310  if (trunc) {
1311  intVal = intVal.trunc(bitwidth);
1312  } else if (unsignIn) {
1313  intVal = intVal.zext(bitwidth);
1314  } else {
1315  intVal = intVal.sext(bitwidth);
1316  }
1317 
1318  return SplatElementsAttr::get(outTy, intVal);
1319  }
1320  }
1321 
1322  return {};
1323 }
1324 
1325 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1326 
1327 OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1328 
1329 #define REDUCE_FOLDER(OP) \
1330  OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1331  ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1332  if (!inputTy.hasRank()) \
1333  return {}; \
1334  if (inputTy != getType()) \
1335  return {}; \
1336  if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1337  return getInput(); \
1338  return {}; \
1339  }
1340 
1341 REDUCE_FOLDER(ReduceAllOp)
1342 REDUCE_FOLDER(ReduceAnyOp)
1343 REDUCE_FOLDER(ReduceMaxOp)
1344 REDUCE_FOLDER(ReduceMinOp)
1345 REDUCE_FOLDER(ReduceProductOp)
1346 REDUCE_FOLDER(ReduceSumOp)
1347 #undef REDUCE_FOLDER
1348 
1349 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1350  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1351  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1352 
1353  if (!inputTy || !outputTy)
1354  return {};
1355 
1356  // Fold when the input and output types are the same. This is only safe when
1357  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
1358  // there may still be a productive reshape.
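  // For example (hypothetical shapes), a reshape between two tensor<?x?xf32>
  // values may still exchange the two dynamic extents at runtime
  // (e.g. 4x8 -> 8x4), so identical types alone do not make it a no-op.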
1359  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1360  return getInput1();
1361 
1362  // reshape(reshape(x)) -> reshape(x)
1363  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1364  getInput1().getDefiningOp())) {
1365  getInput1Mutable().assign(reshapeOp.getInput1());
1366  return getResult();
1367  }
1368 
1369  // Cannot create an ElementsAttr from non-int/float/index types
1370  if (!inputTy.getElementType().isIntOrIndexOrFloat())
1371  return {};
1372 
1373  // reshape(const(x)) -> const(reshape-attr(x))
1374  if (auto operand =
1375  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1376  // Constants must have static shape.
1377  if (!outputTy.hasStaticShape())
1378  return {};
1379 
1380  // Okay to duplicate splat constants.
1381  if (operand.isSplat())
1382  return SplatElementsAttr::get(outputTy,
1383  operand.getSplatValue<Attribute>());
1384 
1385  // Don't duplicate other constants.
1386  if (!getInput1().hasOneUse())
1387  return {};
1388 
1389  llvm::SmallVector<int64_t> shapeVec;
1390  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1391  return {};
1392 
1393  return operand.reshape(
1394  llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1395  }
1396 
1397  return {};
1398 }
1399 
1400 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1401  // If the pad is all zeros we can fold this operation away.
1402  if (adaptor.getPadding() && getInput1().getType() == getType()) {
1403  auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1404  if (densePad && densePad.isSplat() &&
1405  densePad.getSplatValue<APInt>().isZero()) {
1406  return getInput1();
1407  }
1408  }
1409 
1410  return {};
1411 }
1412 
1413 // Fold away cases where a tosa.resize operation returns a copy
1414 // of the input image.
1415 OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1416  auto scaleAttr =
1417  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1418  auto offsetAttr =
1419  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1420  auto borderAttr =
1421  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1422  if (!scaleAttr || !offsetAttr || !borderAttr) {
1423  return {};
1424  }
1425 
1426  auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1427  auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1428  auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1429  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1430  return {};
1431  }
1432 
1433  // Check unit scaling.
1434  if (scale[0] != scale[1] || scale[2] != scale[3]) {
1435  return {};
1436  }
1437 
1438  // There should be no offset.
1439  if (offset[0] != 0 || offset[1] != 0) {
1440  return {};
1441  }
1442 
1443  // There should be no border.
1444  if (border[0] != 0 || border[1] != 0) {
1445  return {};
1446  }
1447 
1448  auto input = getInput();
1449  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1450  auto resultTy = llvm::cast<RankedTensorType>(getType());
1451  if (inputTy != resultTy)
1452  return {};
1453 
1454  return input;
1455 }
1456 
1457 OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1458  auto operand = getInput1();
1459  auto operandTy = llvm::cast<ShapedType>(operand.getType());
1460  auto axis = getAxis();
1461  auto operandAttr =
1462  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1463  if (operandAttr)
1464  return operandAttr;
1465 
1466  // If the dim-length is 1, tosa.reverse is a no-op.
1467  if (operandTy.hasRank() &&
1468  (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1469  return operand;
1470 
1471  return {};
1472 }
1473 
1474 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1475  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1476  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1477 
1478  if (!inputTy || !outputTy)
1479  return {};
1480 
1481  if (inputTy == outputTy && inputTy.hasStaticShape())
1482  return getInput1();
1483 
1484  if (!adaptor.getInput1())
1485  return {};
1486 
1487  // Cannot create an ElementsAttr from non-int/float/index types
1488  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1489  !outputTy.getElementType().isIntOrIndexOrFloat())
1490  return {};
1491 
1492  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1493  if (operand.isSplat() && outputTy.hasStaticShape()) {
1494  return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
1495  }
1496 
1497  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1498  outputTy.getNumElements() == 1) {
1499  DenseElementsAttr startElems;
1500  if (!matchPattern(getStart(), m_Constant(&startElems)))
1501  return {};
1502 
1503  llvm::SmallVector<uint64_t> indices =
1504  llvm::to_vector(startElems.getValues<uint64_t>());
1505  auto value = operand.getValues<Attribute>()[indices];
1506  return SplatElementsAttr::get(outputTy, value);
1507  }
1508 
1509  return {};
1510 }
1511 
1512 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1513  if (getInput2() == getInput3())
1514  return getInput2();
1515 
1516  auto predicate =
1517  llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1518  if (!predicate)
1519  return {};
1520 
1521  if (!predicate.isSplat())
1522  return {};
1523  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
1524  : getInput3();
1525 }
1526 
1527 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1528  if (getInput1().getType() == getType()) {
1529  if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1530  adaptor.getMultiples())) {
1531  if (multiples.isSplat() &&
1532  multiples.getSplatValue<APInt>().getSExtValue() == 1)
1533  return getInput1();
1534  if (auto int_array_attr =
1535  llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1536  if (llvm::all_of(int_array_attr.getValues<APInt>(),
1537  [](APInt v) { return v.getSExtValue() == 1; }))
1538  return getInput1();
1539  }
1540  }
1541  }
1542  return {};
1543 }
1544 
1545 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1546  auto resultTy = llvm::cast<ShapedType>(getType());
1547 
1548  // Transposing splat values just means reshaping.
1549  if (auto input =
1550  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1551  if (input.isSplat() && resultTy.hasStaticShape() &&
1552  input.getType().getElementType() == resultTy.getElementType())
1553  return input.reshape(resultTy);
1554  }
1555 
1556  // Otherwise, the transpose folds only if it is the identity permutation.
1557  const llvm::ArrayRef<int32_t> perms = getPerms();
1558 
1559  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1560  return {};
1561 
1562  return getInput1();
1563 }
1564 
1565 OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
1566  auto input = getInput1();
1567  // Element-wise log(exp(x)) = x
1568  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
1569  return op.getInput1();
1570  }
1571 
1572  return {};
1573 }
1574 
1575 OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
1576  auto input = getInput1();
1577  // Element-wise exp(log(x)) = x
1578  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
1579  return op.getInput1();
1580  }
1581 
1582  return {};
1583 }
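// Schematic example of the fold above (pseudo-notation):
//   tosa.exp (tosa.log %x)  ==>  %x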
1584 
1585 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1586  // Element-wise negate(negate(x)) = x
1587  // when all zero points are constant 0
1588  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1589  if (!definingOp) {
1590  // defining op of input1 is not a negate, cannot fold
1591  return {};
1592  }
1593 
1594  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1595  failed(maybeIZp) || *maybeIZp != 0) {
1596  // input1 zero point is not constant 0, cannot fold
1597  return {};
1598  }
1599  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1600  failed(maybeOZp) || *maybeOZp != 0) {
1601  // output zero point is not constant 0, cannot fold
1602  return {};
1603  }
1604  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1605  failed(maybeIZp) || *maybeIZp != 0) {
1606  // definingOp's input1 zero point is not constant 0, cannot fold
1607  return {};
1608  }
1609  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1610  failed(maybeOZp) || *maybeOZp != 0) {
1611  // definingOp's output zero point is not constant 0, cannot fold
1612  return {};
1613  }
1614 
1615  return definingOp.getInput1();
1616 }
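// Schematic example of the fold above (pseudo-notation):
//   tosa.negate (tosa.negate %x)  ==>  %x
// The fold is only attempted when the input and output zero points of both
// negate ops fold to the constant 0; otherwise the op is left untouched.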
1617 
1618 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1619  auto input = getInput1();
1620  // Element-wise abs(abs(x)) = abs(x)
1621  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1622  return input;
1623  }
1624 
1625  return {};
1626 }
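// Schematic example of the fold above (pseudo-notation): abs is idempotent,
// so the outer op folds to its (already non-negative) input:
//   tosa.abs (tosa.abs %x)  ==>  tosa.abs %x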
1627 
1628 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1629  // Fold consecutive concats on the same axis into a single op.
1630  // Keep track of the operands so we can construct a new concat
1631  // later. Conservatively assume that folding doubles the number of
1632  // operands.
1633  SmallVector<Value, 8> concatOperands;
1634  concatOperands.reserve(2 * getNumOperands());
1635 
1636  // Find all operands that are foldable concats
1637  bool foundFoldableConcat = false;
1638  for (Value operand : getOperands()) {
1639  concatOperands.emplace_back(operand);
1640 
1641  auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
1642  if (!producer)
1643  continue;
1644 
1645  // Not foldable if axes are not the same
1646  if (getAxis() != producer.getAxis())
1647  continue;
1648 
1649  // Replace the original operand with all incoming operands
1650  foundFoldableConcat = true;
1651  concatOperands.pop_back();
1652  llvm::append_range(concatOperands, producer->getOperands());
1653  }
1654 
1655  if (!foundFoldableConcat)
1656  return {};
1657 
1658  getOperation()->setOperands(concatOperands);
1659  return getResult();
1660 }
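// Schematic example of the fold above (pseudo-notation): a concat fed by
// another concat on the same axis is flattened by splicing in the producer's
// operands:
//   tosa.concat (tosa.concat %a, %b {axis = 1}), %c {axis = 1}
//     ==>  tosa.concat %a, %b, %c {axis = 1}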
1661 
1662 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1663  auto input = adaptor.getInput1();
1664 
1665  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1666  // Fold splat inputs only.
1667  if (!inputAttr || !inputAttr.isSplat())
1668  return {};
1669 
1670  auto shapeType = llvm::cast<ShapedType>(getType());
1671  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1672  auto floatVal = inputAttr.getSplatValue<APFloat>();
1673  return DenseElementsAttr::get(shapeType,
1674  ReciprocalOp::calcOneElement(floatVal));
1675  }
1676 
1677  return {};
1678 }
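// Schematic example of the fold above (pseudo-notation): a splat float
// constant folds to a splat of its elementwise reciprocal, e.g. a splat 2.0
// input becomes a splat 0.5 result; non-splat or non-float inputs are left
// to be computed at runtime.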