//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                  op.getInput1().front());
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};
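
// For example (illustrative IR): a concat with a single input,
//   %0 = tosa.concat %arg0 {axis = 0 : i32} : (tensor<4xf32>) -> tensor<4xf32>
// folds to a direct use of %arg0; when the result type differs (e.g. a less
// static tensor<?xf32>), a tensor.cast is inserted instead.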

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}
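
// For example (illustrative IR): the canonicalization
//   %n = tosa.logical_not %p
//   %r = tosa.select %n, %a, %b
// rewrites the select in place to use %p directly with swapped branches,
//   %r = tosa.select %p, %b, %a
// leaving the logical_not to be removed as dead code if it has no other users.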

struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int32_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate the two transposes into one transpose.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permsTy =
        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
    Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
                                                      permsTy, permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);

    return success();
  }
};
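
// Worked example of the composition above: if the inner transpose applies
// perms [1, 2, 0] and the outer applies [2, 0, 1], the fused permutation is
// perms[i] = inner[outer[i]] = [0, 1, 2], i.e. the identity, so the pair
// collapses into a single transpose that the folder below can remove entirely.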

// Determines when a tosa.transpose is equivalent to a tosa.reshape operation.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (isa_and_nonnull<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    // Collect the permuted positions of all non-unit dimensions.
    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    // If the non-unit dimensions stay in increasing order, the transpose only
    // moves size-1 dimensions and leaves the element order in memory intact.
    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        getTosaConstShape(rewriter, op.getLoc(), newShape));
    return success();
  }
};
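
// For example (illustrative IR, shape operand elided): transposing
// tensor<1x512x1xf32> with perms [1, 0, 2] only moves unit dimensions, so the
// op can be replaced by a tosa.reshape to tensor<512x1x1xf32>; the underlying
// element order in memory is unchanged.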

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
      int64_t value = op.getInputZpAttr().getInt();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
  }
};
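
// For example (illustrative IR, operand types elided): a pad with no explicit
// pad_const operand,
//   %0 = tosa.pad %arg0, %padding
// is rewritten to pass the padding value explicitly, here a float zero:
//   %c = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
//   %0 = tosa.pad %arg0, %padding, %c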

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<MaterializePadValue>(context);
}

struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If both the input and output spatial dimensions (NHWC) are 1x1, the
    // pooling window covers exactly one element and the op is a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp>(context);
}

struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    if (!inputType || !inputType.hasStaticShape()) {
      return failure();
    }
    auto inputElementType = inputType.getElementType();

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};
341 
342 // Attempts the following transformation:
343 //
344 // For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
345 // tensor X the following identity holds:
346 //
347 // CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
348 //
349 // subject to the following valid NaN propagation semantics:
350 // --------------------------------------------
351 // | OUTER CLAMP | INNER CLAMP | RESULT MODE |
352 // |-------------|--------------|-------------|
353 // | PROPAGATE | PROPAGATE | PROPAGATE |
354 // | PROPAGATE | IGNORE | IGNORE |
355 // | IGNORE | PROPAGATE | INVALID |
356 // | IGNORE | IGNORE | IGNORE |
357 // |------------------------------------------|
358 
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  // Helper structure to describe the range of a clamp operation.
  template <typename T>
  struct ClampRange {
    ClampRange(const T &start, const T &end) : start(start), end(end) {}
    T start;
    T end;

    // Helper function to determine if two Clamp ranges intersect.
    bool intersects(const ClampRange<T> &otherRange) const {
      return start < otherRange.end && otherRange.start < end;
    }
  };

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    // Check the input to the CLAMP op is itself a CLAMP.
    auto clampOp =
        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    // Check we have intersecting integer ranges.
    const auto opMinInt = op.getMinInt();
    const auto opMaxInt = op.getMaxInt();
    const auto clampOpMinInt = clampOp.getMinInt();
    const auto clampOpMaxInt = clampOp.getMaxInt();
    ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
    ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
    if (!opRangeIntRange.intersects(clampRangeIntRange))
      return failure();

    // Check we have intersecting floating-point ranges.
    const auto opMinFloat = op.getMinFp();
    const auto opMaxFloat = op.getMaxFp();
    const auto clampOpMinFloat = clampOp.getMinFp();
    const auto clampOpMaxFloat = clampOp.getMaxFp();
    ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
    ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
      return failure();

    // Run the transformation.
    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
    const auto minInt = std::max(opMinInt, clampOpMinInt);
    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
  }
};
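
// For example: clamp(clamp(x, 0, 6), 2, 10) on an integer tensor has
// intersecting ranges [0, 6] and [2, 10], so it merges into
// clamp(x, max(0, 2), min(6, 10)) = clamp(x, 2, 6) applied directly to x.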

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate the slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto startOp =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto sizeOp = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, startOp, sizeOp)
                .getResult();
        break;
      }
      // The slice does not begin in this input; rebase the start past it.
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};
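
// For example (illustrative IR): with %c = tosa.concat %a, %b {axis = 0} where
// %a is tensor<2x4xf32> and %b is tensor<3x4xf32>, a slice starting at [3, 0]
// with size [1, 4] lies entirely inside %b, so it becomes a slice of %b alone
// with the start rebased to [1, 0].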

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}
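
// For example: folding tosa.add on two splat constants dense<3> and dense<4>
// of type tensor<2x2xi32> instantiates binaryFolder<std::plus<APInt>,
// std::plus<APFloat>> and produces a new splat dense<7> : tensor<2x2xi32>.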

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x + 0 = x and 0 + x = x, when the shapes already agree.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}

OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type; no need to check for quantized type.
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  // 0 / x = 0.
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  // x / 1 = x.
  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // Fold two splat constants via signed division.
  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace
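
// Worked example of the shifted integer path above: for i32 splats l = 6 and
// r = 12 with shift = 2, both operands are sign-extended to 64 bits, the
// product 72 is computed without overflow, shifted right to 72 >> 2 = 18, and
// truncated back to i32.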

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The result right shift applies to the i32 data type only. For simplicity,
  // synthesize a zero shift for the other data types.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shiftElem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shiftElem)))
        // Cannot be folded when the shift value is unknown.
        return {};
      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
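
// For example: multiplying x by a splat 1 (or by 1 << shift when an i32
// result carries a constant shift) folds to x, and a splat 0 operand folds to
// a zero splat of the result type.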

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x - 0 = x.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value to itself is always true. We cannot do the
  // same for floats, since NaN compares unequal to itself.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // Float -> float: convert between the two float semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // Int -> float.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // Float -> int.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // Int -> int: truncate, or zero/sign extend based on input signedness.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
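
// For example: casting a splat dense<1.25> : tensor<4xf32> to tensor<4xi8>
// rounds to nearest even and folds to dense<1> : tensor<4xi8>, while a splat
// dense<-1> : tensor<4xi8> cast to tensor<4xi32> sign-extends to dense<-1>.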

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
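
// For example: tosa.reduce_sum over axis 1 of a tensor<4x1xf32> already has a
// unit reduction dimension, so when the input and result types match the fold
// returns the input unchanged.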

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    llvm::SmallVector<int64_t> shapeVec;
    if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
  }

  return {};
}
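
// For example: reshaping a splat constant dense<1.0> : tensor<2x3xf32> to
// tensor<6xf32> folds directly to dense<1.0> : tensor<6xf32>; non-splat
// constants are only rewritten when the reshape is their sole user, to avoid
// duplicating large payloads.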

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad && densePad.isSplat() &&
        densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}
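
// For example: a tosa.resize with scale [2, 2, 2, 2] (equal numerator and
// denominator per axis), zero offset and border, and identical input and
// output types is an identity and folds to its input.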

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
  // Reversing a splat constant is a no-op.
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // A slice that extracts exactly one element of a constant can be folded to
  // a splat of that element.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(getStart(), m_Constant(&startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}
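
// For example: slicing dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> with constant
// start [1, 0] and a single-element result type tensor<1x1xi32> folds to the
// constant dense<3> : tensor<1x1xi32>.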

OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (getInput1().getType() == getType()) {
    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getMultiples())) {
      if (multiples.isSplat() &&
          multiples.getSplatValue<APInt>().getSExtValue() == 1)
        return getInput1();
      if (auto intArrayAttr = llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
        if (llvm::all_of(intArrayAttr.getValues<APInt>(),
                         [](APInt v) { return v.getSExtValue() == 1; }))
          return getInput1();
      }
    }
  }
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Otherwise, fold the transpose away only when the permutation is constant
  // and is the identity.
  SmallVector<int32_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}
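
// For example (illustrative IR): concat(concat(%a, %b, axis = 1), %c, axis = 1)
// folds in place to concat(%a, %b, %c, axis = 1); inner concats on a different
// axis are left untouched.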

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}