//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Tensor Data Engine Operators.
//===----------------------------------------------------------------------===//

// Check that the zero point of the tensor and padding operations are aligned.
bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
  // Check that padConst is a constant value and a scalar tensor
  DenseElementsAttr padConstAttr;
  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
      (padConstAttr.size() != 1)) {
    return false;
  }

  // Check that floating point pad is zero
  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
    return padConstVal == 0.0f;
  }

  // Check that the zp and padConst align for the integer (quantized) case
  if (auto padConstIntAttr =
          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
    DenseIntElementsAttr zpAttr;
    // Check that zp is a constant value and a scalar tensor
    if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
      return false;
    }

    // Check equality
    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
    return zpVal == padConstVal;
  }

  // Bail-out on unsupported type
  return false;
}

namespace {
template <typename OpTy>
struct PoolPadFoldAdaptor;

template <>
struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
  using OpTy = tosa::AvgPool2dOp;
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
        op.getAccType());
  }
};

template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
  using OpTy = tosa::MaxPool2dOp;
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  static bool checkPadConstCompliance(OpTy, Value padConst) {
    // Check that padConst is a constant value and a scalar tensor
    DenseElementsAttr padConstAttr;
    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
        padConstAttr.size() != 1) {
      return false;
    }

    // The pad constant needs to be the minimum value to be able to merge
    if (auto padConstFpAttr =
            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
      const APFloat padConstVal = *padConstFpAttr.begin();
      const APFloat lowestVal =
          APFloat::getLargest(padConstVal.getSemantics(), true);
      return padConstVal == lowestVal;
    } else if (auto padConstIntAttr =
                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
      const APInt padConstVal = *padConstIntAttr.begin();
      const unsigned int bitWidth = padConstVal.getBitWidth();
      const APInt lowestVal =
          padConstIntAttr.getElementType().isUnsignedInteger()
              ? APInt::getZero(bitWidth)
              : APInt::getSignedMinValue(bitWidth);
      return padConstVal == lowestVal;
    }

    // Bail-out on unsupported type
    return false;
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
        op, op.getType(), padInput, op.getKernel(), op.getStride(),
        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
  }
};

template <typename OpTy>
struct ConvPadFoldAdaptor {
  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
    return true;
  }
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<OpTy>(
        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
  }
};

// Pattern that attempts to fold a `tosa.pad` operator into a following tensor
// operation like `tosa.conv2d` by merging the padding associated with the
// pad operator directly into the implicit padding of the tensor operation.
// This helps eliminate the explicit padding operator if it becomes unused.
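// For example (illustrative, operands and attributes abbreviated), a
// zero-padded tosa.pad feeding a tosa.conv2d with no implicit padding:
//   %p = tosa.pad %x ...           (pads H and W by 1)
//   %y = tosa.conv2d %p, %w, %b {pad = [0, 0, 0, 0], ...}
// can be rewritten so the conv2d carries the padding itself:
//   %y = tosa.conv2d %x, %w, %b {pad = [1, 1, 1, 1], ...}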
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tensorOp,
                                PatternRewriter &rewriter) const override {
    // Check producer is a tosa::PadOp
    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Producer must be a tosa::PadOp.");

    // Validate that tensor operation has sane padding
    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
      return rewriter.notifyMatchFailure(
          tensorOp, "Tensor operation padding shall have 4 elements.");

    // Validate tosa::PadOp padding
    DenseIntElementsAttr padOpPadding;
    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "The `padding` input specified on the tosa::PadOp must be constant.");
    }
    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
    // C_after
    if (padOpPadding.size() != 8)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Pad padding should have 8 elements.");
    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
      return rewriter.notifyMatchFailure(
          tensorOp, "Folding padding in N or C dimensions is not supported.");

    // Fold padding from Pad into the tensor operation
    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
    SmallVector<int64_t> foldedPad(tensorOpPad.size());
    foldedPad[0] = padHBefore + tensorOpPad[0];
    foldedPad[1] = padHAfter + tensorOpPad[1];
    foldedPad[2] = padWBefore + tensorOpPad[2];
    foldedPad[3] = padWAfter + tensorOpPad[3];

    // Check kernel related restrictions
    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size not aligned with kernel restrictions.");
    }

    // Check padding constant restrictions
    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "Padding constant is not aligned with operator zero-point.");
    }

    // Check that padding doesn't grow more than 8K level (8192) for now
    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size more than the 8K level limit.");
    }

    // Create operator
    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
                                   foldedPad);

    return success();
  }
};
} // namespace

void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
      context);
}

void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<
      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
      context);
}

void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
      context);
}

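// Replaces a tosa.max_pool2d with its input when both the input and the
// output have statically known unit height and width (N x 1 x 1 x C), since
// pooling over a single element returns that element unchanged.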
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If both the input and output spatial dimensions (H and W) are 1, then
    // this is a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp,
              FoldPadToTensorOp<tosa::MaxPool2dOp,
                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
      context);
}

//===----------------------------------------------------------------------===//
// Data Layout / Memory Reinterpretation.
//===----------------------------------------------------------------------===//

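// A tosa.concat with a single input is a no-op: replace it with its input,
// inserting a tensor.cast when the operand and result types differ.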
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter
          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                              op.getInput1().front())
          .getResult();
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

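// Canonicalizes select(logical_not(pred), on_true, on_false) into
// select(pred, on_false, on_true) by swapping the last two operands.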
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getInput3(), op.getInput2()});
  });
  return success();
}

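// Consolidates transpose(transpose(x)) into a single transpose that operates
// on the inner transpose's input with the two permutations composed.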
struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
    const llvm::ArrayRef<int32_t> innerTransposePerms =
        innerTranspose.getPerms();

    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate transposes into one transpose.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));

    return success();
  }
};

// Determines when a tosa.transpose is equivalent to a tosa.reshape operation.
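// For example (illustrative): transposing tensor<5x1x6xf32> with perms
// [1, 0, 2] only moves a unit dimension, so it is equivalent to reshaping
// to tensor<1x5x6xf32>.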
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    const llvm::ArrayRef<int32_t> permValues = op.getPerms();

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        getTosaConstShape(rewriter, op.getLoc(), newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();

    if (!inputType.hasStaticShape()) {
      return failure();
    }

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp =
          llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
      auto maxClamp =
          llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
      bool isMin = minClamp.isNegInfinity();
      bool isMax = maxClamp.isInfinity();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
      int64_t maxClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
      int64_t maxClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

// Attempts the following transformation:
//
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
// tensor X the following identity holds:
//
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
//
// subject to the following valid NaN propagation semantics:
// ---------------------------------------------
// | OUTER CLAMP | INNER CLAMP | RESULT MODE   |
// |-------------|-------------|---------------|
// | PROPAGATE   | PROPAGATE   | PROPAGATE     |
// | PROPAGATE   | IGNORE      | IGNORE        |
// | IGNORE      | PROPAGATE   | INVALID       |
// | IGNORE      | IGNORE      | IGNORE        |
// ---------------------------------------------
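//
// For example, CLAMP(CLAMP(X, 1.0, 10.0), 3.0, 15.0) = CLAMP(X, 3.0, 10.0).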

struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  // Helper structure to describe the range of a clamp operation.
  template <typename T>
  struct ClampRange {
    ClampRange(const T &start, const T &end) : start(start), end(end) {}
    T start;
    T end;

    // Helper function to determine if two Clamp ranges intersect.
    bool intersects(const ClampRange<T> &otherRange) {
      return start < otherRange.end && otherRange.start < end;
    }
  };

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    // Check the input to the CLAMP op is itself a CLAMP.
    auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    auto maxValAttr = op.getMaxValAttr();
    auto minValAttr = op.getMinValAttr();
    auto clampOpMaxValAttr = clampOp.getMaxValAttr();
    auto clampOpMinValAttr = clampOp.getMinValAttr();

    auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
    if (auto quantType =
            llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
      inputEType = quantType.getStorageType();
    }

    Attribute newMinValAttr, newMaxValAttr;
    if (mlir::isa<FloatType>(inputEType)) {
      auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
      auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
      auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
      auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);

      // Check we have intersecting ranges.
      const auto opMinFloat = floatMinValAttr.getValue();
      const auto opMaxFloat = floatMaxValAttr.getValue();
      const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
      const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
      ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
      ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
                                               clampOpMaxFloat);
      if (!opRangeFloatRange.intersects(clampRangeFloatRange))
        return failure();

      // Run the transformation.
      auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
      auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
      newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
      newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
    } else {
      assert(mlir::isa<IntegerType>(inputEType));
      auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
      auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
      auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
      auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);

      if (inputEType.isUnsignedInteger()) {
        // Check we have intersecting ranges.
        const auto opMinInt = intMinValAttr.getUInt();
        const auto opMaxInt = intMaxValAttr.getUInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
        ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
        ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
                                                     clampOpMaxInt);
        if (!opRangeIntRange.intersects(clampRangeIntRange))
          return failure();

        // Run the transformation.
        auto newMinVal = std::max(opMinInt, clampOpMinInt);
        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
      } else {
        // Check we have intersecting ranges.
        const auto opMinInt = intMinValAttr.getInt();
        const auto opMaxInt = intMaxValAttr.getInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
        ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
        ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
                                                    clampOpMaxInt);
        if (!opRangeIntRange.intersects(clampRangeIntRange))
          return failure();

        // Run the transformation.
        auto newMinVal = std::max(opMinInt, clampOpMinInt);
        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
      }
    }

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

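// When a tosa.slice consumes a tosa.concat and the sliced region lies
// entirely within a single concat operand, slice that operand directly and
// bypass the concat. For example (illustrative): slicing rows [4, 6) of a
// concat of two 4-row tensors along axis 0 only reads the second operand, so
// the slice can read rows [0, 2) of that operand instead.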
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, start_op, size_op)
                .getResult();
        break;
      }
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};

// Update the size operand of a tosa.slice if it has dynamic dims but the
// corresponding output dims are static.
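// For example (illustrative): a tosa.slice with size [-1, 4] whose result
// type is tensor<2x4xf32> can have its size operand rewritten to [2, 4].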
struct SliceDynamicSizeCanonicalization
    : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());

    ElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    }

    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    bool replaceSliceSize{false};
    // If the size operand has a -1 indicating a dynamic shape but the
    // corresponding dim of the output is statically known, update the size to
    // match the known output dim shape.
    for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
      if (size == -1 && !resultType.isDynamicDim(index)) {
        sliceSizes[index] = resultType.getDimSize(index);
        replaceSliceSize = true;
      }
    }

    if (!replaceSliceSize) {
      return rewriter.notifyMatchFailure(
          sliceOp, "no dimension of size of slice is dynamic that resolves "
                   "to static output shape");
    }

    auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
    auto newSliceOp = rewriter.create<tosa::SliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
        sliceOp.getStart(), size_op);

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
      context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

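// Folds an elementwise binary operation when both operands are splat
// constants with the same element type, applying IntFolder to integer splats
// and FloatFolder to floating-point splats.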
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

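// add(x, splat(0)) folds to x; when both operands are splat constants the sum
// is computed directly.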
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}

OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type, no need to check for quantized type
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace

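// mul by a splat of zero folds to a zero splat, mul by a splat of one
// (accounting for the shift on i32) folds to the other operand, and two splat
// constants are multiplied directly.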
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The result right shift applies to the i32 data type only. For simplicity,
  // synthesize a zero shift for the other data types.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // If we are comparing an integer value to itself it is always true. We
  // cannot do this for floats because a NaN value is never equal to itself.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

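// A cast to the same type folds to its input. Splat constants are converted
// at fold time for the float-to-float, int-to-float, float-to-int, and
// int-to-int cases handled below.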
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

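// A reduction is an identity when the reduced axis has size 1 (or the input
// is rank 0) and the result type matches the input type.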
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    llvm::SmallVector<int64_t> shapeVec;
    if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad && densePad.isSplat() &&
        densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  auto scaleAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
  auto offsetAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
  auto borderAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
  if (!scaleAttr || !offsetAttr || !borderAttr) {
    return {};
  }

  auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
  auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
  auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
    return {};
  }

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(getStart(), m_Constant(&startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getInput2() == getInput3())
    return getInput2();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
                                                         : getInput3();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (getInput1().getType() == getType()) {
    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getMultiples())) {
      if (multiples.isSplat() &&
          multiples.getSplatValue<APInt>().getSExtValue() == 1)
        return getInput1();
      if (auto int_array_attr =
              llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
        if (llvm::all_of(int_array_attr.getValues<APInt>(),
                         [](APInt v) { return v.getSExtValue() == 1; }))
          return getInput1();
      }
    }
  }
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Bail out if the transpose is not the identity permutation.
  const llvm::ArrayRef<int32_t> perms = getPerms();

  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  // Element-wise negate(negate(x)) = x
  // iff all zero points are constant 0
  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
  if (!definingOp) {
    // defining op of input1 is not a negate, cannot fold
    return {};
  }

  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    // input1 zero point is not constant 0, cannot fold
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    // output zero point is not constant 0, cannot fold
    return {};
  }
  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    // definingOp's input1 zero point is not constant 0, cannot fold
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    // definingOp's output zero point is not constant 0, cannot fold
    return {};
  }

  return definingOp.getInput1();
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if axes are not the same
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}