//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

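// A concat with a single operand is a no-op: forward the operand directly, or
// insert a tensor.cast when the operand and result types differ.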
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter
          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                              op.getInput1().front())
          .getResult();
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

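// select(logical_not(pred), on_true, on_false) is rewritten in place to
// select(pred, on_false, on_true), dropping the logical_not.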
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.updateRootInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}

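// transpose(transpose(x)) is consolidated into a single transpose whose
// permutation is the composition of the two constant permutations.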
struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int64_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate transposes into one transpose.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permsTy =
        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
    Value permsValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);

    return success();
  }
};

// Determines when a tosa.transpose is effectively a tosa.reshape operation.
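// For example, transposing a tensor<1x4x1xf32> with perms [2, 0, 1] only moves
// unit dimensions, so it can be replaced by a reshape to tensor<1x1x4xf32>.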
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        rewriter.getDenseI64ArrayAttr(newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

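// If a tosa.pad has no explicit pad_const operand, materialize one: zero for
// float and plain integer types, or the input zero point when the pad carries
// quantization info.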
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
      auto value = op.getQuantizationInfo()->getInputZp();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
  }
};

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<MaterializePadValue>(context);
}

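// A max_pool2d whose input and output both have 1x1 spatial dimensions pools
// over a single element per channel and is therefore the identity.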
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If the output and input shapes are 1x1, then this is a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp>(context);
}

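// A clamp is a no-op when its bounds cover the entire range representable by
// the element type: [-inf, +inf] for floats, or the full (un)signed integer
// range for integers.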
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();

    if (!inputType.hasStaticShape()) {
      return failure();
    }

    if (inputElementType.isa<FloatType>()) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

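// Two nested clamps collapse into one whose bounds are the intersection of the
// original ranges, e.g. clamp(clamp(x, 2, 8), 4, 10) becomes clamp(x, 4, 8).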
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    Operation *definingOp = input.getDefiningOp();
    if (!definingOp)
      return failure();

    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
      auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
      auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();

      auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
      auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
          op, op.getType(), clampOp.getInput(),
          rewriter.getI64IntegerAttr(minInt),
          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
          rewriter.getF32FloatAttr(maxFp));
      return success();
    }

    return failure();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

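// A slice that reads from only one operand of a preceding concat is rewritten
// to slice that operand directly, with the start offset shifted by the sizes
// of the operands that precede it along the concat axis.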
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStart[axis] >= 0 &&
          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
        replaceWithSlice = rewriter
                               .create<tosa::SliceOp>(
                                   sliceOp.getLoc(), sliceOp.getType(), input,
                                   rewriter.getDenseI64ArrayAttr(sliceStart),
                                   rewriter.getDenseI64ArrayAttr(sliceSize))
                               .getResult();
        break;
      }
      sliceStart[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

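// Folds an elementwise binary op whose operands are both splat constants of
// the same element type, applying IntFolder to APInt splats and FloatFolder to
// APFloat splats.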
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

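// Helpers to recognize splat constants equal to the additive identity (zero)
// and, accounting for tosa.mul's shift, the multiplicative identity (1 << shift
// for integers, 1.0 for floats).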
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

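// Integer tosa.div folds to zero when the lhs is a splat zero constant,
// forwards the lhs when the rhs is a splat one, and performs signed division
// when both operands are splat constants.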
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

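// Folds a splat-by-splat tosa.mul. For integers with a non-zero shift, the
// product is computed at double bit width, shifted right by `shift`, and then
// truncated back to the original width.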
namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value against itself is always true. We cannot do the
  // same for floats because NaN compares unequal to itself.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

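// tosa.cast folds when the input type already matches the result type, or when
// the input is a splat constant, in which case the splat value is converted
// between float semantics, int and float, or integer widths (with zero- or
// sign-extension depending on signedness).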
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

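// Reductions along an axis of size one (or over a rank-0 tensor) do not change
// the data, so each of the reduce ops below folds to its input in that case.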
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding()) {
    auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

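// tosa.slice folds to its input when the types already match, to a splat
// constant when the input is splat, or to the single extracted element when
// the result has exactly one element.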
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput();

  if (!adaptor.getInput())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices(getStart());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

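// tosa.select folds when both branches are the same value, or when the
// predicate is a splat constant and one branch can be chosen statically.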
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        inputTy.getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Folding is only valid when the transpose does not change the type ...
  if (getInput1().getType() != getType())
    return {};

  // ... and the permutation is a constant identity permutation.
  SmallVector<int64_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}