//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

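// Fold away a tosa.concat with a single input: the result is just the
// operand, with a tensor.cast inserted when the result type differs from the
// operand type.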
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter
          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                              op.getInput1().front())
          .getResult();
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

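// Push a logical_not on the predicate into tosa.select by swapping the
// on_true / on_false operands:
//   select(logical_not(p), a, b) -> select(p, b, a)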
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}

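// Fold a pair of back-to-back transposes into a single transpose whose
// permutation is the composition of the two:
//   transpose(transpose(x, perms1), perms2) -> transpose(x, composed_perms)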
struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int64_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate transposes into one transpose.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permsTy =
        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
    Value permsValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);

    return success();
  }
};

// Determine when a tosa.transpose is really a tosa.reshape: if it only moves
// dimensions of size 1, the element order is unchanged and the transpose can
// be rewritten as a reshape.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        rewriter.getDenseI64ArrayAttr(newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

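// If tosa.pad has no explicit pad_const value, materialize one: zero for
// float and integer inputs without quantization info, or the input zero point
// when quantization info is present.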
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
      auto value = op.getQuantizationInfo()->getInputZp();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
  }
};

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<MaterializePadValue>(context);
}

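// A tosa.max_pool2d whose input and output both have 1x1 spatial dimensions
// simply forwards its input.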
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If both the input and output spatial (H, W) dimensions are 1x1, this is
    // a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp>(context);
}

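// A tosa.clamp whose bounds cover the whole representable range of its
// element type (-inf/+inf for floats, the full integer range for integers)
// does not change any value and can be removed.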
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    if (!inputType || !inputType.hasStaticShape()) {
      return failure();
    }
    auto inputElementType = inputType.getElementType();

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

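// clamp(clamp(x, a, b), c, d) can be collapsed into a single clamp that uses
// the tighter of the two bounds: clamp(x, max(a, c), min(b, d)).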
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    Operation *definingOp = input.getDefiningOp();
    if (!definingOp)
      return failure();

    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
      auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
      auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();

      auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
      auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
          op, op.getType(), clampOp.getInput(),
          rewriter.getI64IntegerAttr(minInt),
          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
          rewriter.getF32FloatAttr(maxFp));
      return success();
    }

    return failure();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

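// If a tosa.slice taken from a tosa.concat lies entirely within a single
// concat operand, slice that operand directly and bypass the concat.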
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStart[axis] >= 0 &&
          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
        replaceWithSlice = rewriter
                               .create<tosa::SliceOp>(
                                   sliceOp.getLoc(), sliceOp.getType(), input,
                                   rewriter.getDenseI64ArrayAttr(sliceStart),
                                   rewriter.getDenseI64ArrayAttr(sliceSize))
                               .getResult();
        break;
      }
      sliceStart[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

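// Folds an elementwise binary operation whose operands are both splat
// constants of the same element type, applying IntFolder to integer splats
// and FloatFolder to float splats and producing a splat of the result type.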
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

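// Helpers for recognizing splat constants that act as the identity for
// add/sub (zero) or mul (one, accounting for tosa.mul's shift on integers).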
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

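// add(x, 0) and add(0, x) fold to x when the operand type already matches the
// result type; adding two splat constants folds to a new constant.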
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

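// argmax along an axis of size 1 always produces index 0.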
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}

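// tosa.div folds 0 / x to 0, x / 1 to x, and the division of two integer
// splat constants to their signed quotient.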
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

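// Folds tosa.mul of two splat constants; for integer element types the
// product is computed at twice the bit width, shifted right by the op's shift
// amount, and truncated back to the original width.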
namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

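// Functors that fold elementwise comparisons of splat operands into an i1
// splat. The greater / greater-equal variants use signed APInt comparisons.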
namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value to itself is always true. We cannot fold the
  // same way for floats because NaN is never equal to itself.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

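// Folds tosa.cast when the input and output types already match, or when the
// input is a splat constant, in which case the splat value is converted
// between the float and integer element types directly.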
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

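// A reduction along an axis of size 1 (or over a rank-0 tensor) whose input
// and result types match returns its input unchanged.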
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                     \
      return {};                                                                \
    if (inputTy != getType())                                                   \
      return {};                                                                \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)           \
      return getInput();                                                        \
    return {};                                                                  \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding()) {
    auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput();

  if (!adaptor.getInput())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices(getStart());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

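// tosa.select folds when both branches are the same value, or when the
// predicate is a splat constant and one branch can be chosen statically.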
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}

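// A tosa.tile whose multiples are all 1 and whose result type matches its
// input type is an identity.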
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        inputTy.getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // A foldable transpose must not change the type.
  if (getInput1().getType() != getType())
    return {};

  // Only the identity permutation folds to the input.
  SmallVector<int64_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}