1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/OpDefinition.h"
27 #include "mlir/IR/PatternMatch.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/Sequence.h"
31 
32 #include <type_traits>
33 
34 using namespace mlir;
35 using namespace mlir::tosa;
36 
37 // Helper function to materialize the semantically correct compare and select
38 // operations given a binary operation with a specific NaN propagation mode.
39 //
40 // In the case of "PROPAGATE" semantics no compare and selection is required and
41 // this function does nothing.
42 //
43 // In the case of "IGNORE" semantics this function materializes a comparison of
44 // the current operands to the op which will return true for any NaN
45 // argument and then selects between the non-NaN operation argument and the
46 // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
47 // code:
48 //
49 // binary<op>(lhs, rhs):
50 //   result = op(lhs, rhs)
51 //   if lhs == NaN return rhs
52 //   if rhs == NaN return lhs
53 //   return result
54 //
55 // In the case that the op operates on non-floating-point types, the
56 // attribute is ignored completely; this is consistent with the TOSA spec,
57 // which states: "This attribute is ignored by non floating-point types."
58 //
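// As an illustrative sketch (hand-written here, not verified compiler
// output), an "IGNORE"-mode tosa.maximum roughly lowers to the following
// sequence inside the linalg.generic body:
//
//   %max   = arith.maximumf %lhs, %rhs : f32
//   %l_nan = arith.cmpf uno, %lhs, %lhs : f32
//   %r_nan = arith.cmpf uno, %rhs, %rhs : f32
//   %tmp   = arith.select %l_nan, %rhs, %max : f32
//   %res   = arith.select %r_nan, %lhs, %tmp : f32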
59 template <typename OpTy>
60 static Value
61 materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
62  Value lhs, Value rhs, Value result) {
63  // NaN propagation has no meaning for non floating point types.
64  if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
65  return result;
66 
67  auto nanMode = op.getNanMode();
68  if (nanMode == NanPropagationMode::PROPAGATE)
69  return result;
70 
71  // Unordered comparison of NaN against itself will always return true.
72  Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
73  arith::CmpFPredicate::UNO, lhs, lhs);
74  Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
75  arith::CmpFPredicate::UNO, rhs, rhs);
76  Value rhsOrResult =
77  arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
78  return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
79  rhsOrResult);
80 }
81 
82 static Value createLinalgBodyCalculationForElementwiseOp(
83  Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
84  ConversionPatternRewriter &rewriter) {
85  Location loc = op->getLoc();
86  auto elementTy =
87  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
88 
89  // tosa::AbsOp
90  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
91  return math::AbsFOp::create(rewriter, loc, resultTypes, args);
92 
93  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
94  auto zero = arith::ConstantOp::create(rewriter, loc,
95  rewriter.getZeroAttr(elementTy));
96  auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
97  return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
98  }
99 
100  // tosa::AddOp
101  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
102  return arith::AddFOp::create(rewriter, loc, resultTypes, args);
103 
104  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
105  return arith::AddIOp::create(rewriter, loc, resultTypes, args);
106 
107  // tosa::SubOp
108  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
109  return arith::SubFOp::create(rewriter, loc, resultTypes, args);
110 
111  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
112  return arith::SubIOp::create(rewriter, loc, resultTypes, args);
113 
114  // tosa::IntDivOp
115  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
116  return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
117 
118  // tosa::ReciprocalOp
119  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
120  auto one =
121  arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
122  return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
123  }
124 
125  // tosa::MulOp
126  if (isa<tosa::MulOp>(op)) {
127  auto shiftVal = cast<tosa::MulOp>(op).getShift();
128  DenseElementsAttr shiftElem;
129  bool shiftIsConstant = true;
130  int32_t shift = 0;
131  if (matchPattern(shiftVal, m_Constant(&shiftElem)))
132  shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
133  else
134  shiftIsConstant = false;
135 
136  if (isa<FloatType>(elementTy)) {
137  if (shift != 0) {
138  (void)rewriter.notifyMatchFailure(op,
139  "Cannot have shift value for float");
140  return nullptr;
141  }
142  return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
143  args[1]);
144  }
145 
146  if (isa<IntegerType>(elementTy)) {
147  Value a = args[0];
148  Value b = args[1];
149 
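 // The path below (a non-zero or non-constant shift) widens both operands to
 // i32 and defers to tosa.apply_scale, which roughly computes
 // (a * b) >> shift with the selected rounding mode; a non-constant shift is
 // forwarded as the third block argument (args[2]).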
150  if (shift > 0 || !shiftIsConstant) {
151  Value shiftConst;
152  if (shiftIsConstant)
153  shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
154  /*bitwidth=*/8);
155 
156  if (!a.getType().isInteger(32))
157  a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
158 
159  if (!b.getType().isInteger(32))
160  b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
161 
162  auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
163  auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
164  RoundingMode::SINGLE_ROUND);
165  auto result =
166  tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
167  b, shiftAmount, roundingAttr);
168 
169  return result;
170  }
171 
172  int aWidth = a.getType().getIntOrFloatBitWidth();
173  int bWidth = b.getType().getIntOrFloatBitWidth();
174  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
175 
176  if (aWidth < cWidth)
177  a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
178  if (bWidth < cWidth)
179  b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);
180 
181  return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
182  }
183  }
184 
185  // tosa::NegateOp
186  if (isa<tosa::NegateOp>(op)) {
187  auto negate = cast<tosa::NegateOp>(op);
188 
189  FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
190  if (failed(maybeInZp)) {
191  (void)rewriter.notifyMatchFailure(
192  op, "input1 zero point cannot be statically determined");
193  return nullptr;
194  }
195 
196  FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
197  if (failed(maybeOutZp)) {
198  (void)rewriter.notifyMatchFailure(
199  op, "output zero point cannot be statically determined");
200  return nullptr;
201  }
202 
203  int64_t inZp = *maybeInZp;
204  int64_t outZp = *maybeOutZp;
205 
206  if (isa<FloatType>(elementTy))
207  return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
208 
209  if (isa<IntegerType>(elementTy)) {
210  if (!inZp && !outZp) {
211  auto constant = arith::ConstantOp::create(
212  rewriter, loc, IntegerAttr::get(elementTy, 0));
213  return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
214  args[0]);
215  }
216 
217  // Compute the maximum value that can occur in the intermediate buffer.
218  const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
219  const int64_t zpAdd = inZp + outZp;
220  const int64_t maxValue =
221  APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
222  std::abs(zpAdd) + 1;
223 
224  // Convert that maximum value into the maximum bitwidth needed to
225  // represent it. We assume 48-bit numbers may be supported further in
226  // the pipeline.
227  int intermediateBitWidth = 64;
228  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
229  intermediateBitWidth = 16;
230  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
231  intermediateBitWidth = 32;
232  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
233  intermediateBitWidth = 48;
234  }
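 // Worked example (illustrative): for an i8 negate with inZp = -128 and
 // outZp = 0, zpAdd = -128 and maxValue = 127 + 128 + 1 = 256, which fits in
 // 16 bits, so the subtraction below is performed in i16 before being
 // clamped and truncated back to i8.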
235 
236  Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
237  Value zpAddValue = arith::ConstantOp::create(
238  rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
239 
240  // The negation can be applied by doing:
241  // outputValue = inZp + outZp - inputValue
242  auto ext =
243  arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
244  auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
245 
246  // Clamp to the negation range.
247  auto min = arith::ConstantIntOp::create(
248  rewriter, loc, intermediateType,
249  APInt::getSignedMinValue(inputBitWidth).getSExtValue());
250  auto max = arith::ConstantIntOp::create(
251  rewriter, loc, intermediateType,
252  APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
253  auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
254 
255  // Truncate to the final value.
256  return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
257  }
258  }
259 
260  // tosa::BitwiseAndOp
261  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
262  return arith::AndIOp::create(rewriter, loc, resultTypes, args);
263 
264  // tosa::BitwiseOrOp
265  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
266  return arith::OrIOp::create(rewriter, loc, resultTypes, args);
267 
268  // tosa::BitwiseNotOp
269  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
270  auto allOnesAttr = rewriter.getIntegerAttr(
271  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
272  auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
273  return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
274  }
275 
276  // tosa::BitwiseXOrOp
277  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
278  return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
279 
280  // tosa::LogicalLeftShiftOp
281  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
282  return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
283 
284  // tosa::LogicalRightShiftOp
285  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
286  return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
287 
288  // tosa::ArithmeticRightShiftOp
289  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
290  auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
291  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
292  if (!round) {
293  return result;
294  }
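 // Worked example (illustrative): arithmetic_right_shift(-7, 1) with
 // round = true. The plain arithmetic shift yields -4; the last bit shifted
 // out of -7 is 1 and the shift amount is greater than zero, so 1 is added
 // back below, giving -3.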
295 
296  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
297  auto one = arith::ConstantOp::create(rewriter, loc,
298  IntegerAttr::get(elementTy, 1));
299  auto zero = arith::ConstantOp::create(rewriter, loc,
300  IntegerAttr::get(elementTy, 0));
301  auto i1one =
302  arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
303 
304  // Check that the shift amount (input2) is greater than zero.
305  auto shiftValueGreaterThanZero = arith::CmpIOp::create(
306  rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
307 
308  // Check whether the last bit shifted out of input1 is 1.
309  auto subtract =
310  arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
311  auto shifted =
312  arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
313  ->getResults();
314  auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
315  llvm::ArrayRef<NamedAttribute>());
316  auto isInputOdd =
317  arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
318 
319  auto shouldRound = arith::AndIOp::create(
320  rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
321  auto extended =
322  arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
323  return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
324  }
325 
326  // tosa::ClzOp
327  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
328  return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
329  }
330 
331  // tosa::LogicalAnd
332  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
333  return arith::AndIOp::create(rewriter, loc, resultTypes, args);
334 
335  // tosa::LogicalNot
336  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
337  auto one = arith::ConstantOp::create(rewriter, loc,
338  rewriter.getIntegerAttr(elementTy, 1));
339  return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
340  }
341 
342  // tosa::LogicalOr
343  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
344  return arith::OrIOp::create(rewriter, loc, resultTypes, args);
345 
346  // tosa::LogicalXor
347  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
348  return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
349 
350  // tosa::PowOp
351  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
352  return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
353 
354  // tosa::RsqrtOp
355  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
356  return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
357 
358  // tosa::LogOp
359  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
360  return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
361 
362  // tosa::ExpOp
363  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
364  return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
365 
366  // tosa::SinOp
367  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
368  return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
369 
370  // tosa::CosOp
371  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
372  return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
373 
374  // tosa::TanhOp
375  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
376  return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
377 
378  // tosa::ErfOp
379  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
380  return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
381 
382  // tosa::GreaterOp
383  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
384  return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
385  args[0], args[1]);
386 
387  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
388  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
389  args[0], args[1]);
390 
391  // tosa::GreaterEqualOp
392  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
393  return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
394  args[0], args[1]);
395 
396  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
397  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
398  args[0], args[1]);
399 
400  // tosa::EqualOp
401  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
402  return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
403  args[0], args[1]);
404 
405  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
406  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
407  args[0], args[1]);
408 
409  // tosa::SelectOp
410  if (isa<tosa::SelectOp>(op)) {
411  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
412  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
413  return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
414  }
415 
416  // tosa::MaximumOp
417  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
418  auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
419  return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
420  rewriter, args[0], args[1], max);
421  }
422 
423  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
424  return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
425  }
426 
427  // tosa::MinimumOp
428  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
429  auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
430  return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
431  rewriter, args[0], args[1], min);
432  }
433 
434  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
435  return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
436  }
437 
438  // tosa::CeilOp
439  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
440  return math::CeilOp::create(rewriter, loc, resultTypes, args);
441 
442  // tosa::FloorOp
443  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
444  return math::FloorOp::create(rewriter, loc, resultTypes, args);
445 
446  // tosa::ClampOp
447  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
448  bool losesInfo = false;
449  APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
450  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
451  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
452  APFloat::rmNearestTiesToEven, &losesInfo);
453  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
454  APFloat::rmNearestTiesToEven, &losesInfo);
455  auto min = arith::ConstantOp::create(
456  rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
457  auto max = arith::ConstantOp::create(
458  rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
459  auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
460 
461  auto clampOp = llvm::cast<tosa::ClampOp>(op);
462  const auto nanMode = clampOp.getNanMode();
463 
464  // NaN propagation has no meaning for non floating point types.
465  if (!isa<FloatType>(elementTy))
466  return result;
467 
468  // In the case of "PROPAGATE" semantics no compare and selection is
469  // required.
470  if (nanMode == NanPropagationMode::PROPAGATE)
471  return result;
472 
473  // In the case of "IGNORE" semantics materialize a comparison
474  // of the current operand to the reduction which will return true for a NaN
475  // argument and then selects between the initial reduction value and the
476  // calculated result based on whether the argument is NaN or not. In pseudo
477  // code:
478  //
479  // reduce<op>(x, init):
480  // result = op(init, x)
481  // return init if x == NaN else result
482 
483  // Unordered comparison of NaN against itself will always return true.
484  Value isNaN = arith::CmpFOp::create(
485  rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
486  // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
487  // is NaN.
488  return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
489  }
490 
491  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
492  auto intTy = cast<IntegerType>(elementTy);
493  int64_t min =
494  cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
495  int64_t max =
496  cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
497 
498  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
499  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
500  if (intTy.isUnsignedInteger()) {
501  minRepresentable = 0;
502  if (intTy.getIntOrFloatBitWidth() <= 63) {
503  maxRepresentable =
504  (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
505  .getZExtValue();
506  }
507  } else if (intTy.getIntOrFloatBitWidth() <= 64) {
508  // Ensure that min & max fit into signed n-bit constants.
509  minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
510  .getSExtValue();
511  maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
512  .getSExtValue();
513  }
514  // Ensure that the bounds are representable as n-bit signed/unsigned
515  // integers.
516  min = std::max(min, minRepresentable);
517  max = std::max(max, minRepresentable);
518  min = std::min(min, maxRepresentable);
519  max = std::min(max, maxRepresentable);
520 
521  auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
522  intTy.getIntOrFloatBitWidth());
523  auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
524  intTy.getIntOrFloatBitWidth());
525  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
526  intTy.isUnsignedInteger());
527  }
528 
529  // tosa::SigmoidOp
530  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
531  auto one =
532  arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
533  auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
534  auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
535  auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
536  return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
537  }
538 
539  // tosa::CastOp
540  if (isa<tosa::CastOp>(op)) {
541  Type srcTy = elementTy;
542  Type dstTy = resultTypes.front();
543  if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
544  (void)rewriter.notifyMatchFailure(op, "unsupported type");
545  return nullptr;
546  }
547 
548  bool bitExtend =
549  srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
550 
551  if (srcTy == dstTy)
552  return args.front();
553 
554  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
555  return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
556  llvm::ArrayRef<NamedAttribute>());
557 
558  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
559  return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
560  llvm::ArrayRef<NamedAttribute>());
561 
562  // 1-bit integers need to be treated as signless.
563  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
564  return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
565  llvm::ArrayRef<NamedAttribute>());
566 
567  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
568  return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
569  llvm::ArrayRef<NamedAttribute>());
570 
571  // Unsigned integers need an unrealized cast so that they can be passed
572  // to UIToFP.
573  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
574  auto unrealizedCast =
575  UnrealizedConversionCastOp::create(
576  rewriter, loc,
577  rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
578  .getResult(0);
579  return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
580  unrealizedCast);
581  }
582 
583  // All other si-to-fp conversions should be handled by SIToFP.
584  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
585  return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
586  llvm::ArrayRef<NamedAttribute>());
587 
588  // Casting to boolean, floats need to only be checked as not-equal to zero.
589  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
590  Value zero = arith::ConstantOp::create(rewriter, loc,
591  rewriter.getFloatAttr(srcTy, 0.0));
592  return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
593  args.front(), zero);
594  }
595 
596  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
597  auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
598 
599  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
600  // Check whether neither int min nor int max can be represented in the
601  // input floating-point type due to too short exponent range.
602  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
603  APFloat::semanticsMaxExponent(fltSemantics)) {
604  // Use cmp + select to replace infinites by int min / int max. Other
605  // integral values can be represented in the integer space.
606  auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
607  auto posInf = arith::ConstantOp::create(
608  rewriter, loc,
609  rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
610  APFloat::getInf(fltSemantics)));
611  auto negInf = arith::ConstantOp::create(
612  rewriter, loc,
613  rewriter.getFloatAttr(
614  getElementTypeOrSelf(srcTy),
615  APFloat::getInf(fltSemantics, /*Negative=*/true)));
616  auto overflow = arith::CmpFOp::create(
617  rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
618  auto underflow = arith::CmpFOp::create(
619  rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
620  auto intMin = arith::ConstantOp::create(
621  rewriter, loc,
622  rewriter.getIntegerAttr(
623  getElementTypeOrSelf(dstTy),
624  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
625  auto intMax = arith::ConstantOp::create(
626  rewriter, loc,
627  rewriter.getIntegerAttr(
628  getElementTypeOrSelf(dstTy),
629  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
630  auto maxClamped =
631  arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
632  return arith::SelectOp::create(rewriter, loc, underflow, intMin,
633  maxClamped);
634  }
635 
636  auto intMinFP = arith::ConstantOp::create(
637  rewriter, loc,
638  rewriter.getFloatAttr(
639  getElementTypeOrSelf(srcTy),
640  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
641  .getSExtValue()));
642 
643  // Check whether the mantissa has enough bits to represent int max.
644  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
645  dstTy.getIntOrFloatBitWidth() - 1) {
646  // Int min can also be represented since it is a power of two and thus
647  // consists of a single leading bit. Therefore we can clamp the input
648  // in the floating-point domain.
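 // For instance (illustrative): an f32 source (24-bit mantissa) casting to
 // i16 satisfies 24 >= 15, so both bounds are exactly representable and the
 // clamp can be done on the floating-point side; f32 to i64 does not
 // (24 < 63) and falls through to the int-max + 1 comparison further below.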
649 
650  auto intMaxFP = arith::ConstantOp::create(
651  rewriter, loc,
652  rewriter.getFloatAttr(
653  getElementTypeOrSelf(srcTy),
654  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
655  .getSExtValue()));
656 
657  Value clamped =
658  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
659  return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
660  }
661 
662  // Due to the earlier check we know the exponent range is big enough to
663  // int min. We can therefore rely on int max + 1 being representable as
664  // well because it's just int min with a positive sign. So clamp the min
665  // value and compare against that to select the max int value if needed.
666  auto intMaxPlusOneFP = arith::ConstantOp::create(
667  rewriter, loc,
668  rewriter.getFloatAttr(
669  getElementTypeOrSelf(srcTy),
670  static_cast<double>(
671  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
672  .getSExtValue()) +
673  1.0f));
674 
675  auto intMax = arith::ConstantOp::create(
676  rewriter, loc,
677  rewriter.getIntegerAttr(
678  getElementTypeOrSelf(dstTy),
679  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
680  auto minClampedFP =
681  arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
682  auto minClamped =
683  arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
684  auto overflow = arith::CmpFOp::create(
685  rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
686  return arith::SelectOp::create(rewriter, loc, overflow, intMax,
687  minClamped);
688  }
689 
690  // Casting to boolean, integers need to only be checked as not-equal to
691  // zero.
692  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
693  Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
694  srcTy.getIntOrFloatBitWidth());
695  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
696  args.front(), zero);
697  }
698 
699  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
700  return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
701  llvm::ArrayRef<NamedAttribute>());
702 
703  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
704  return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
705  }
706  }
707 
708  (void)rewriter.notifyMatchFailure(
709  op, "unhandled op for linalg body calculation for elementwise op");
710  return nullptr;
711 }
712 
713 using IndexPool = DenseMap<int64_t, Value>;
714 
715 // Emit an 'arith.constant' op for the given index if it has not been created
716 // yet, or return an existing constant. This prevents excessive creation of
717 // redundant constants and eases readability of the emitted code in unit tests.
718 static Value createIndex(PatternRewriter &rewriter, Location loc,
719  IndexPool &indexPool, int64_t index) {
720  auto [it, inserted] = indexPool.try_emplace(index);
721  if (inserted)
722  it->second =
723  arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
724  return it->second;
725 }
726 
727 static Value getTensorDim(PatternRewriter &rewriter, Location loc,
728  IndexPool &indexPool, Value tensor, int64_t index) {
729  auto indexValue = createIndex(rewriter, loc, indexPool, index);
730  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
731 }
732 
733 static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
734  IndexPool &indexPool, Value tensor,
735  int64_t index) {
736  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
737  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
738  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
739  if (shapedType.isDynamicDim(index))
740  return getTensorDim(rewriter, loc, indexPool, tensor, index);
741  return rewriter.getIndexAttr(shapedType.getDimSize(index));
742 }
743 
744 static bool operandsAndResultsRanked(Operation *operation) {
745  auto isRanked = [](Value value) {
746  return isa<RankedTensorType>(value.getType());
747  };
748  return llvm::all_of(operation->getOperands(), isRanked) &&
749  llvm::all_of(operation->getResults(), isRanked);
750 }
751 
752 // Compute the runtime dimension size for dimension 'dim' of the output by
753 // inspecting input 'operands', all of which are expected to have the same rank.
754 // This function returns a pair {targetSize, masterOperand}.
755 //
756 // The runtime size of the output dimension is returned either as a statically
757 // computed attribute or as a runtime SSA value.
758 //
759 // If the target size was inferred directly from one dominating operand, that
760 // operand is returned in 'masterOperand'. If the target size is inferred from
761 // multiple operands, 'masterOperand' is set to nullptr.
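// For example (illustrative): with per-operand sizes {1, ?, 4} for this
// dimension, the static 4 wins and that operand becomes the master operand;
// with {1, ?, ?} the size is the runtime maximum of the two dynamic sizes and
// no single master operand exists, so nullptr is returned.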
762 static std::pair<OpFoldResult, Value>
763 computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
764  ValueRange operands, int64_t dim) {
765  // If any input operand contains a static size greater than 1 for this
766  // dimension, that is the target size. An occurrence of an additional static
767  // dimension greater than 1 with a different value is undefined behavior.
768  for (auto operand : operands) {
769  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
770  if (ShapedType::isStatic(size) && size > 1)
771  return {rewriter.getIndexAttr(size), operand};
772  }
773 
774  // Filter operands with dynamic dimension
775  auto operandsWithDynamicDim =
776  llvm::filter_to_vector(operands, [&](Value operand) {
777  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
778  });
779 
780  // If no operand has a dynamic dimension, it means all sizes were 1
781  if (operandsWithDynamicDim.empty())
782  return {rewriter.getIndexAttr(1), operands.front()};
783 
784  // Emit code that computes the runtime size for this dimension. If there is
785  // only one operand with a dynamic dimension, it is considered the master
786  // operand that determines the runtime size of the output dimension.
787  auto targetSize =
788  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
789  if (operandsWithDynamicDim.size() == 1)
790  return {targetSize, operandsWithDynamicDim[0]};
791 
792  // Calculate maximum size among all dynamic dimensions
793  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
794  auto nextSize =
795  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
796  targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
797  }
798  return {targetSize, nullptr};
799 }
800 
801 // Compute the runtime output size for all dimensions. This function returns
802 // a pair {targetShape, masterOperands}.
803 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
804 computeTargetShape(PatternRewriter &rewriter, Location loc,
805  IndexPool &indexPool, ValueRange operands) {
806  assert(!operands.empty());
807  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
808  SmallVector<OpFoldResult> targetShape;
809  SmallVector<Value> masterOperands;
810  for (auto dim : llvm::seq<int64_t>(0, rank)) {
811  auto [targetSize, masterOperand] =
812  computeTargetSize(rewriter, loc, indexPool, operands, dim);
813  targetShape.push_back(targetSize);
814  masterOperands.push_back(masterOperand);
815  }
816  return {targetShape, masterOperands};
817 }
818 
819 static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
820  IndexPool &indexPool, Value operand,
821  int64_t dim, OpFoldResult targetSize,
822  Value masterOperand) {
823  // Nothing to do if this is a static dimension
824  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
825  if (!rankedTensorType.isDynamicDim(dim))
826  return operand;
827 
828  // If the target size for this dimension was directly inferred by only taking
829  // this operand into account, there is no need to broadcast. This is an
830  // optimization that will prevent redundant control flow, and constitutes the
831  // main motivation for tracking "master operands".
832  if (operand == masterOperand)
833  return operand;
834 
835  // Affine maps for 'linalg.generic' op
836  auto rank = rankedTensorType.getRank();
837  SmallVector<AffineExpr> affineExprs;
838  for (auto index : llvm::seq<int64_t>(0, rank)) {
839  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
840  : rewriter.getAffineDimExpr(index);
841  affineExprs.push_back(affineExpr);
842  }
843  auto broadcastAffineMap =
844  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
845  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
846  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
847 
848  // Check if broadcast is necessary
849  auto one = createIndex(rewriter, loc, indexPool, 1);
850  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
851  auto broadcastNecessary = arith::CmpIOp::create(
852  rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
853 
854  // Emit 'then' region of 'scf.if'
855  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
856  // It is not safe to cache constants across regions.
857  // New constants could potentially violate dominance requirements.
858  IndexPool localPool;
859 
860  // Emit 'tensor.empty' op
861  SmallVector<OpFoldResult> outputTensorShape;
862  for (auto index : llvm::seq<int64_t>(0, rank)) {
863  auto size = index == dim ? targetSize
864  : getOrFoldTensorDim(rewriter, loc, localPool,
865  operand, index);
866  outputTensorShape.push_back(size);
867  }
868  Value outputTensor = tensor::EmptyOp::create(
869  opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
870 
871  // Emit 'linalg.generic' op
872  auto resultTensor =
873  linalg::GenericOp::create(
874  opBuilder, loc, outputTensor.getType(), operand, outputTensor,
875  affineMaps, getNParallelLoopsAttrs(rank),
876  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
877  // Emit 'linalg.yield' op
878  linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
879  })
880  .getResult(0);
881 
882  // Cast to original operand type if necessary
883  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
884  loc, operand.getType(), resultTensor);
885 
886  // Emit 'scf.yield' op
887  scf::YieldOp::create(opBuilder, loc, castResultTensor);
888  };
889 
890  // Emit 'else' region of 'scf.if'
891  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
892  scf::YieldOp::create(opBuilder, loc, operand);
893  };
894 
895  // Emit 'scf.if' op
896  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
897  emitThenRegion, emitElseRegion);
898  return ifOp.getResult(0);
899 }
900 
901 static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
902  IndexPool &indexPool, Value operand,
903  ArrayRef<OpFoldResult> targetShape,
904  ArrayRef<Value> masterOperands) {
905  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
906  assert((int64_t)targetShape.size() == rank);
907  assert((int64_t)masterOperands.size() == rank);
908  for (auto index : llvm::seq<int64_t>(0, rank))
909  operand =
910  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
911  targetShape[index], masterOperands[index]);
912  return operand;
913 }
914 
915 static SmallVector<Value>
916 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
917  IndexPool &indexPool, ValueRange operands,
918  ArrayRef<OpFoldResult> targetShape,
919  ArrayRef<Value> masterOperands) {
920  // No need to broadcast for unary operations
921  if (operands.size() == 1)
922  return operands;
923 
924  // No need to broadcast for static shape
925  bool hasDynamic = false;
926  for (auto op : operands) {
927  const auto tType = dyn_cast<RankedTensorType>(op.getType());
928  if (tType && !tType.hasStaticShape()) {
929  hasDynamic = true;
930  break;
931  }
932  }
933  if (!hasDynamic)
934  return operands;
935 
936  // Broadcast dynamic dimensions operand by operand
937  return llvm::map_to_vector(operands, [&](Value operand) {
938  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
939  targetShape, masterOperands);
940  });
941 }
942 
943 static LogicalResult
944 emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
945  Operation *operation, ValueRange operands,
946  ArrayRef<OpFoldResult> targetShape,
947  const TypeConverter &converter) {
948  // Generate output tensor
949  auto resultType = cast_or_null<RankedTensorType>(
950  converter.convertType(operation->getResultTypes().front()));
951  if (!resultType) {
952  return rewriter.notifyMatchFailure(operation, "failed to convert type");
953  }
954  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
955  resultType.getElementType());
956 
957  // Create affine maps. Input affine maps broadcast static dimensions of size
958  // 1. The output affine map is an identity map.
959  //
960  auto rank = resultType.getRank();
961  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
962  auto shape = cast<ShapedType>(operand.getType()).getShape();
963  SmallVector<AffineExpr> affineExprs;
964  for (auto it : llvm::enumerate(shape)) {
965  // Prefer producing identity maps whenever possible (i.e. no broadcasting
966  // needed) because some transforms (like reshape folding)
967  // do not support affine constant exprs.
968  bool requiresBroadcast =
969  (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
970  auto affineExpr = requiresBroadcast
971  ? rewriter.getAffineConstantExpr(0)
972  : rewriter.getAffineDimExpr(it.index());
973  affineExprs.push_back(affineExpr);
974  }
975  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
976  });
977  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
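 // For example (illustrative): with operands tensor<1x5xf32> and
 // tensor<3x5xf32> and result tensor<3x5xf32>, the maps above are
 // (d0, d1) -> (0, d1) for the first input and the identity
 // (d0, d1) -> (d0, d1) for the second input and for the output.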
978 
979  // Emit 'linalg.generic' op
980  bool encounteredError = false;
981  auto linalgOp = linalg::GenericOp::create(
982  rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
983  getNParallelLoopsAttrs(rank),
984  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
985  Value opResult = createLinalgBodyCalculationForElementwiseOp(
986  operation, blockArgs.take_front(operation->getNumOperands()),
987  {resultType.getElementType()}, rewriter);
988  if (!opResult) {
989  encounteredError = true;
990  return;
991  }
992  linalg::YieldOp::create(opBuilder, loc, opResult);
993  });
994  if (encounteredError)
995  return rewriter.notifyMatchFailure(
996  operation, "unable to create linalg.generic body for elementwise op");
997 
998  // Cast 'linalg.generic' result into original result type if needed
999  auto castResult = rewriter.createOrFold<tensor::CastOp>(
1000  loc, resultType, linalgOp->getResult(0));
1001  rewriter.replaceOp(operation, castResult);
1002  return success();
1003 }
1004 
1005 static ValueRange getBroadcastableOperands(Operation *operation,
1006  ValueRange operands) {
1007  // Shift cannot broadcast
1008  if (isa<tosa::MulOp>(operation)) {
1009  DenseElementsAttr shiftElems;
1010  // Shift cannot broadcast when it is constant
1011  if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
1012  return operands.take_front(2);
1013  else
1014  return operands.take_front(3);
1015  }
1016  // Input1_zp and output_zp cannot broadcast
1017  if (isa<tosa::NegateOp>(operation))
1018  return operands.take_front(1);
1019  return operands;
1020 }
1021 
1022 static LogicalResult
1023 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
1024  ConversionPatternRewriter &rewriter,
1025  const TypeConverter &converter) {
1026 
1027  // Collect op properties
1028  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1029  assert(operation->getNumOperands() >= 1 &&
1030  "elementwise op expects at least 1 operand");
1031  if (!operandsAndResultsRanked(operation))
1032  return rewriter.notifyMatchFailure(operation,
1033  "Unranked tensors not supported");
1034 
1035  // Lower operation
1036  IndexPool indexPool;
1037  auto loc = operation->getLoc();
1038  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1039  auto [targetShape, masterOperands] =
1040  computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
1041  auto broadcastOperands =
1042  broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
1043  targetShape, masterOperands);
1044  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
1045  targetShape, converter);
1046 }
1047 
1048 // Returns the constant initial value for a given reduction operation. The
1049 // attribute type varies depending on the element type required.
1050 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1051  PatternRewriter &rewriter) {
1052  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1053  return rewriter.getFloatAttr(elementTy, 0.0);
1054 
1055  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1056  return rewriter.getIntegerAttr(elementTy, 0);
1057 
1058  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1059  return rewriter.getFloatAttr(elementTy, 1.0);
1060 
1061  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1062  return rewriter.getIntegerAttr(elementTy, 1);
1063 
1064  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1065  return rewriter.getFloatAttr(
1066  elementTy, APFloat::getLargest(
1067  cast<FloatType>(elementTy).getFloatSemantics(), false));
1068 
1069  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1070  return rewriter.getIntegerAttr(
1071  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1072 
1073  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1074  return rewriter.getFloatAttr(
1075  elementTy, APFloat::getLargest(
1076  cast<FloatType>(elementTy).getFloatSemantics(), true));
1077 
1078  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1079  return rewriter.getIntegerAttr(
1080  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1081 
1082  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1083  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1084 
1085  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1086  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1087 
1088  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1089  return rewriter.getFloatAttr(
1090  elementTy, APFloat::getLargest(
1091  cast<FloatType>(elementTy).getFloatSemantics(), true));
1092 
1093  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1094  return rewriter.getIntegerAttr(
1095  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1096 
1097  return {};
1098 }
1099 
1100 // Creates the body calculation for a reduction. The operations vary depending
1101 // on the input type.
1102 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1103  ValueRange args,
1104  Type elementTy,
1105  PatternRewriter &rewriter) {
1106  Location loc = op->getLoc();
1107  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1108  return arith::AddFOp::create(rewriter, loc, args);
1109  }
1110 
1111  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1112  return arith::AddIOp::create(rewriter, loc, args);
1113  }
1114 
1115  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1116  return arith::MulFOp::create(rewriter, loc, args);
1117  }
1118 
1119  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1120  return arith::MulIOp::create(rewriter, loc, args);
1121  }
1122 
1123  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1124  return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
1125  }
1126 
1127  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1128  return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
1129  }
1130 
1131  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1132  return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
1133  }
1134 
1135  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1136  return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
1137  }
1138 
1139  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1140  return arith::AndIOp::create(rewriter, loc, args);
1141 
1142  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1143  return arith::OrIOp::create(rewriter, loc, args);
1144 
1145  return {};
1146 }
1147 
1148 // Performs the match and rewrite for reduction operations. This includes
1149 // declaring a correctly sized initial value, and the linalg.generic operation
1150 // that reduces across the specified axis.
1151 template <typename OpTy>
1152 static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1153  PatternRewriter &rewriter) {
1154  auto loc = op->getLoc();
1155  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1156  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1157  if (!inputTy || !resultTy)
1158  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1159 
1160  auto elementTy = resultTy.getElementType();
1161  Value input = op->getOperand(0);
1162 
1163  // Figure out the accType if needed
1164  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1165  isa<FloatType>(elementTy) &&
1166  cast<FloatType>(elementTy).isBF16();
1167  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
1168 
1169  SmallVector<int64_t> reduceShape;
1170  SmallVector<Value> dynDims;
1171  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1172  if (axis != i) {
1173  reduceShape.push_back(inputTy.getDimSize(i));
1174  if (inputTy.isDynamicDim(i))
1175  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1176  }
1177  }
1178 
1179  SmallVector<Value> inputs, outputs;
1180  inputs.push_back(input);
1181 
1182  // First fill the output buffer with the init value.
1183  auto emptyTensor =
1184  tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1185  .getResult();
1186 
1187  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
1188  if (!fillValueAttr)
1189  return rewriter.notifyMatchFailure(
1190  op, "No initial value found for reduction operation");
1191 
1192  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
1193  auto filledTensor =
1194  linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
1195  ValueRange{emptyTensor})
1196  .result();
1197  outputs.push_back(filledTensor);
1198 
1199  bool isNanIgnoreMode = false;
1200  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1201  std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1202  // NaN propagation has no meaning for non floating point types.
1203  if (isa<FloatType>(elementTy) &&
1204  op.getNanMode() == NanPropagationMode::IGNORE) {
1205  isNanIgnoreMode = true;
1206  // Because the TOSA spec requires that the result be NaN iff all elements
1207  // in the reduction are NaN, we can't simply perform a compare and select.
1208  // Additionally we have to keep track of whether we've seen any non-NaN
1209  // values and then do a final select based on this predicate.
1210  auto trueAttr = rewriter.getBoolAttr(true);
1211  auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
1212  auto emptyBoolTensor =
1213  tensor::EmptyOp::create(rewriter, loc, reduceShape,
1214  trueValue.getType(), dynDims)
1215  .getResult();
1216  auto allResultsNaNTensor =
1217  linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
1218  ValueRange{emptyBoolTensor})
1219  .result();
1220  // Note that because the linalg::ReduceOp has two variadic arguments
1221  // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1222  // need to have the same number of inputs and outputs.
1223  //
1224  // The second input isn't actually used anywhere since the value used to
1225  // update the NaN flag is calculated inside the body of the reduction and
1226  // then used to update an out value.
1227  // In order to satisfy type constraints we just pass another copy of the
1228  // input here.
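 // With this arrangement the reduction body below receives four block
 // arguments (illustrative layout): blockArgs[0] is the current input
 // element, blockArgs[1] the unused duplicate input, blockArgs[2] the
 // accumulator, and blockArgs[3] the running "all results are NaN" flag.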
1229  inputs.push_back(input);
1230  outputs.push_back(allResultsNaNTensor);
1231  }
1232  }
1233 
1234  bool didEncounterError = false;
1235  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
1236  rewriter, loc, inputs, outputs, axis,
1237  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1238  std::array<Value, 2> binaryArgs{
1239  blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1240 
1241  // If reduction type differs then extend (applicable to reduce_sum)
1242  if (binaryArgs[0].getType() != accTy)
1243  binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
1244  binaryArgs[0]);
1245 
1246  auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
1247  accTy, rewriter);
1248  if (result)
1249  didEncounterError = true;
1250 
1251  SmallVector<Value> resultsToYield;
1252  if (isNanIgnoreMode) {
1253  auto inputValue = blockArgs[0];
1254  auto initialValue = blockArgs[2];
1255  auto oldAllResultsNanFlagValue = blockArgs[3];
1256 
1257  // Unordered comparison of NaN against itself will always return true.
1258  Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
1259  arith::CmpFPredicate::UNO,
1260  inputValue, inputValue);
1261  // If we've encountered a NaN, take the non-NaN value.
1262  auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
1263  isNaN, initialValue, result);
1264  // Update the flag which keeps track of whether we have seen a non-NaN
1265  // value.
1266  auto newAllResultsNanFlagValue = arith::AndIOp::create(
1267  nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1268  resultsToYield.push_back(selectOp);
1269  resultsToYield.push_back(newAllResultsNanFlagValue);
1270  } else {
1271  resultsToYield.push_back(result);
1272  }
1273  linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
1274  });
1275 
1276  if (!didEncounterError)
1277  return rewriter.notifyMatchFailure(
1278  op, "unable to create linalg.generic body for reduce op");
1279 
1280  if (isNanIgnoreMode) {
1281  // Materialize a check to see whether we encountered any non-NaN values. If
1282  // we didn't, we need to select a tensor of NaNs, since the result would
1283  // otherwise just be the initial identity value propagated through all the
1284  // compares and selects inside the reduction.
1285 
1286  // Create a tensor full of NaNs.
1287  auto nanValueAttr = rewriter.getFloatAttr(
1288  accTy,
1289  APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1290  auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
1291  auto emptyNanTensor =
1292  tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1293  .getResult();
1294  auto nanFilledTensor =
1295  linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
1296  ValueRange{emptyNanTensor})
1297  .result();
1298 
1299  // Create an empty tensor; no need to fill it since it will be
1300  // overwritten by the select.
1301  auto finalEmptyTensor =
1302  tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1303  .getResult();
1304 
1305  // Do a selection between the tensors akin to:
1306  // result = NaN if "all results NaN" else result.
1307  SmallVector<Value> ins, outs;
1308  ins.push_back(linalgOp->getOpResult(1));
1309  ins.push_back(nanFilledTensor);
1310  ins.push_back(linalgOp->getResult(0));
1311  outs.push_back(finalEmptyTensor);
1312  auto linalgSelect =
1313  linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
1314  linalgOp = linalgSelect;
1315  }
1316 
1317  // Truncate back to resultTy if needed
1318  Value reducedRes = linalgOp->getResult(0);
1319  if (widenAccTy) {
1320  auto resEmptyOp =
1321  tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
1322  .getResult();
1323 
1324  const unsigned reducedRank =
1325  cast<ShapedType>(reducedRes.getType()).getRank();
1326  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
1327  reducedRes =
1328  linalg::GenericOp::create(
1329  rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
1330  ValueRange{resEmptyOp},
1331  ArrayRef<AffineMap>{identityMap, identityMap},
1332  getNParallelLoopsAttrs(reducedRank),
1333  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1334  Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
1335  elementTy, args[0]);
1336  linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
1337  })
1338  .getResults()[0];
1339  }
1340 
1341  SmallVector<ReassociationExprs, 4> reassociationMap;
1342  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
1343  reassociationMap.resize(expandInputRank);
1344 
1345  for (uint64_t i = 0; i < expandInputRank; i++) {
1346  int32_t dimToPush = i > axis ? i + 1 : i;
1347  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1348  }
1349 
1350  if (expandInputRank != 0) {
1351  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1352  reassociationMap[expandedDim].push_back(
1353  rewriter.getAffineDimExpr(expandedDim + 1));
1354  }
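 // Worked example (illustrative): reducing axis 1 of tensor<2x3x4xf32> yields
 // tensor<2x4xf32>; the reassociation [[0], [1, 2]] built above then expands
 // it back to the tosa result type tensor<2x1x4xf32>.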
1355 
1356  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1357  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1358  // not have access to such information. This matters when handling dynamically
1359  // sized tensors.
1360  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
1361  reassociationMap);
1362  return success();
1363 }
1364 
1365 namespace {
1366 
1367 template <typename SrcOp>
1368 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1369 public:
1370  using OpConversionPattern<SrcOp>::OpConversionPattern;
1371  using typename OpConversionPattern<SrcOp>::OpAdaptor;
1372 
1373  LogicalResult
1374  matchAndRewrite(SrcOp op, OpAdaptor operands,
1375  ConversionPatternRewriter &rewriter) const final {
1376  return elementwiseMatchAndRewriteHelper(
1377  op, operands.getOperands(), rewriter, *this->getTypeConverter());
1378  }
1379 };
1380 
1381 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1382 public:
1383  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1384 
1385  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1386  PatternRewriter &rewriter) const final {
1387  auto loc = op.getLoc();
1388  auto input = op.getInput();
1389  auto inputTy = cast<ShapedType>(op.getInput().getType());
1390  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1391  unsigned rank = inputTy.getRank();
1392 
1393  // This is an illegal configuration; terminate and log an error.
1394  if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
1395  return rewriter.notifyMatchFailure(
1396  op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1397  "currently supported");
1398  if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
1399  return rewriter.notifyMatchFailure(
1400  op, "tosa.rescale requires scale32 for double_round to be true");
1401 
1402  if (!isa<IntegerType>(inputTy.getElementType()))
1403  return rewriter.notifyMatchFailure(op, "only support integer type");
1404 
1405  SmallVector<Value> dynDims;
1406  for (int i = 0; i < outputTy.getRank(); i++) {
1407  if (outputTy.isDynamicDim(i)) {
1408  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1409  }
1410  }
1411 
1412  // The shift and multiplier values.
1413  DenseElementsAttr shiftElems;
1414  if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1415  return rewriter.notifyMatchFailure(
1416  op, "tosa.rescale requires constant shift input values");
1417 
1418  DenseElementsAttr multiplierElems;
1419  if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1420  return rewriter.notifyMatchFailure(
1421  op, "tosa.rescale requires constant multiplier input values");
1422 
1423  llvm::SmallVector<int8_t> shiftValues =
1424  llvm::to_vector(shiftElems.getValues<int8_t>());
1425  // An explicit cast to int32_t is required here.
1426  llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1427  llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1428  [](IntegerAttr attr) -> int32_t {
1429  return static_cast<int32_t>(attr.getInt());
1430  }));
1431 
1432  // If we shift by more than the bit width, the result is always 0, so clear
1432  // both the multiplier and the shift.
1433  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1434  if (shiftValues[i] > 63) {
1435  shiftValues[i] = 0;
1436  multiplierValues[i] = 0;
1437  }
1438  }
1439 
1440  // Double rounding only has an effect when the shift is greater than 31;
1441  // check whether that is ever the case so the mode can be relaxed otherwise.
1442 
1443  bool doubleRound =
1444  op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1445  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1446  RoundingMode roundingMode =
1447  doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
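  // For example (illustrative): with a 32-bit multiplier of 1 << 30 and
  // shift = 31, tosa.apply_scale computes roughly
  // round(value * 2^30 / 2^31) = round(value / 2); since no shift exceeds 31
  // in that case, the rounding mode can safely be relaxed to SINGLE_ROUND.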
1448 
1449  SmallVector<AffineMap> indexingMaps = {
1450  rewriter.getMultiDimIdentityMap(rank)};
1451  SmallVector<Value, 4> genericInputs = {input};
1452 
1453  // If we are rescaling per-channel then we need to store the multiplier
1454  // values in a buffer.
1455  Value multiplierConstant;
1456  int64_t multiplierArg = 0;
1457  if (multiplierValues.size() == 1) {
1458  multiplierConstant = arith::ConstantOp::create(
1459  rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1460  } else {
1461  SmallVector<AffineExpr, 2> multiplierExprs{
1462  rewriter.getAffineDimExpr(rank - 1)};
1463  auto multiplierType =
1464  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1465  rewriter.getI32Type());
1466  genericInputs.push_back(arith::ConstantOp::create(
1467  rewriter, loc,
1468  DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1469 
1470  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1471  /*symbolCount=*/0, multiplierExprs,
1472  rewriter.getContext()));
1473 
1474  multiplierArg = indexingMaps.size() - 1;
1475  }
1476 
1477  // If we are rescaling per-channel then we need to store the shift
1478  // values in a buffer.
1479  Value shiftConstant;
1480  int64_t shiftArg = 0;
1481  if (shiftValues.size() == 1) {
1482  shiftConstant = arith::ConstantOp::create(
1483  rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1484  } else {
1485  SmallVector<AffineExpr, 2> shiftExprs = {
1486  rewriter.getAffineDimExpr(rank - 1)};
1487  auto shiftType =
1488  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1489  rewriter.getIntegerType(8));
1490  genericInputs.push_back(arith::ConstantOp::create(
1491  rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1492  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1493  /*symbolCount=*/0, shiftExprs,
1494  rewriter.getContext()));
1495  shiftArg = indexingMaps.size() - 1;
1496  }
1497 
1498  // Indexing maps for output values.
1499  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1500 
1501  // Create the empty output tensor used by the linalg.generic op.
1502  Value emptyTensor = tensor::EmptyOp::create(
1503  rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
1504  ArrayRef<Value>({dynDims}));
1505 
1506  auto linalgOp = linalg::GenericOp::create(
1507  rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
1508  indexingMaps, getNParallelLoopsAttrs(rank),
1509  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1510  ValueRange blockArgs) {
1511  Value value = blockArgs[0];
1512  Type valueTy = value.getType();
1513 
1514  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1515  if (failed(maybeIZp)) {
1516  (void)rewriter.notifyMatchFailure(
1517  op, "input zero point cannot be statically determined");
1518  return;
1519  }
1520 
1521  const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1522  // Extend the zero point for sub-32-bit widths.
1523  const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1524  auto inputZp = arith::ConstantOp::create(
1525  nestedBuilder, loc,
1526  IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
1527  *maybeIZp));
1528 
1529  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1530  if (failed(maybeOZp)) {
1531  (void)rewriter.notifyMatchFailure(
1532  op, "output zero point cannot be statically determined");
1533  return;
1534  };
1535 
1536  IntegerType outIntType =
1537  cast<IntegerType>(blockArgs.back().getType());
1538  unsigned outBitWidth = outIntType.getWidth();
1539  const int32_t outAttrBitwidth = 32;
1540  assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1541  auto outputZp = arith::ConstantOp::create(
1542  nestedBuilder, loc,
1543  IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
1544  *maybeOZp));
1545 
1546  Value multiplier = multiplierConstant ? multiplierConstant
1547  : blockArgs[multiplierArg];
1548  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1549 
1550  if (valueTy.isUnsignedInteger()) {
1551  value = UnrealizedConversionCastOp::create(
1552  nestedBuilder, nestedLoc,
1553  nestedBuilder.getIntegerType(
1554  valueTy.getIntOrFloatBitWidth()),
1555  value)
1556  .getResult(0);
1557  }
1558  if (valueTy.getIntOrFloatBitWidth() < 32) {
1559  if (op.getInputUnsigned()) {
1560  value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
1561  nestedBuilder.getI32Type(), value);
1562  } else {
1563  value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
1564  nestedBuilder.getI32Type(), value);
1565  }
1566  }
1567 
1568  value =
1569  arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
1570 
1571  value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
1572  nestedBuilder.getI32Type(), value,
1573  multiplier, shift, roundingMode);
1574 
1575  // Move to the new zero-point.
1576  value =
1577  arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
1578 
1579  // Saturate to the output size.
1580  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1581  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1582 
1583  // Unsigned integers have a different output range.
1584  if (op.getOutputUnsigned()) {
1585  intMin = 0;
1586  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1587  }
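  // For example, an i8 output clamps to [-128, 127] when signed and to
  // [0, 255] when the output is unsigned.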
1588 
1589  auto intMinVal = arith::ConstantOp::create(
1590  nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
1591  auto intMaxVal = arith::ConstantOp::create(
1592  nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
1593 
1594  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1595  nestedBuilder, /*isUnsigned=*/false);
1596 
1597  if (outIntType.getWidth() < 32) {
1598  value = arith::TruncIOp::create(
1599  nestedBuilder, nestedLoc,
1600  rewriter.getIntegerType(outIntType.getWidth()), value);
1601  }
1602 
1603  if (outIntType.isUnsignedInteger()) {
1604  value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
1605  outIntType, value)
1606  .getResult(0);
1607  }
1608  linalg::YieldOp::create(nestedBuilder, loc, value);
1609  });
1610 
1611  rewriter.replaceOp(op, linalgOp->getResults());
1612  return success();
1613  }
1614 };
1615 
1616 // Handle the resize case where the input is a 1x1 image. This case
1617 // can entirely avoid the extract operations, which are much more
1618 // difficult to optimize away.
1619 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1620 public:
1621  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1622 
1623  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1624  PatternRewriter &rewriter) const final {
1625  Location loc = op.getLoc();
1626  ImplicitLocOpBuilder builder(loc, rewriter);
1627  auto input = op.getInput();
1628  auto inputTy = cast<RankedTensorType>(input.getType());
1629  auto resultTy = cast<RankedTensorType>(op.getType());
1630  const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
1631 
1632  auto inputH = inputTy.getDimSize(1);
1633  auto inputW = inputTy.getDimSize(2);
1634  auto outputH = resultTy.getDimSize(1);
1635  auto outputW = resultTy.getDimSize(2);
1636 
1637  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1638  return rewriter.notifyMatchFailure(
1639  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1640 
1641  if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1642  op.getMode() != ResizeMode::BILINEAR)
1643  return rewriter.notifyMatchFailure(
1644  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1645 
1646  if (inputTy == resultTy) {
1647  rewriter.replaceOp(op, input);
1648  return success();
1649  }
1650 
1651  SmallVector<int64_t> scale;
1652  if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
1653  return failure();
1654  }
1655 
1656  // Collapse the unit width and height away.
1657  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1658  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1659  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1660  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1661  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1662 
1663  auto collapseTy =
1664  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1665  inputTy.getElementType());
1666  Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
1667  reassociationMap);
1668 
1669  // Get any dynamic shapes that appear in the input type.
1670  llvm::SmallVector<Value> outputDynSize;
1671  if (inputTy.isDynamicDim(0))
1672  outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1673  if (inputTy.isDynamicDim(3))
1674  outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1675 
1676  // Generate the elementwise operation that casts and scales the input value.
1677  auto genericTy = collapseTy.clone(resultTy.getElementType());
1678  Value empty =
1679  tensor::EmptyOp::create(builder, genericTy.getShape(),
1680  resultTy.getElementType(), outputDynSize);
1681  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1682  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1683  utils::IteratorType::parallel);
1684 
1685  auto generic = linalg::GenericOp::create(
1686  builder, genericTy, ValueRange{collapse}, ValueRange{empty},
1687  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1688  [=](OpBuilder &b, Location loc, ValueRange args) {
1689  Value value = args[0];
1690  // This is the quantized case.
1691  if (inputTy.getElementType() != resultTy.getElementType()) {
1692  value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
1693  value);
1694 
1695  if (isBilinear && scale[0] != 0) {
1696  Value scaleY = arith::ConstantOp::create(
1697  b, loc, b.getI32IntegerAttr(scale[0]));
1698  value = arith::MulIOp::create(b, loc, value, scaleY);
1699  }
1700 
1701  if (isBilinear && scale[2] != 0) {
1702  Value scaleX = arith::ConstantOp::create(
1703  b, loc, b.getI32IntegerAttr(scale[2]));
1704  value = arith::MulIOp::create(b, loc, value, scaleX);
1705  }
1706  }
1707 
1708  linalg::YieldOp::create(b, loc, value);
1709  });
1710 
1711  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1712  op, resultTy, generic.getResults()[0], reassociationMap);
1713  return success();
1714  }
1715 };
1716 
1717 // TOSA resize with width or height of 1 may be broadcasted to a wider
1718 // dimension. This is done by materializing a new tosa.resize without
1719 // the broadcasting behavior, and an explicit broadcast afterwards.
1720 class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1721 public:
1722  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1723 
1724  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1725  PatternRewriter &rewriter) const final {
1726  Location loc = op.getLoc();
1727  ImplicitLocOpBuilder builder(loc, rewriter);
1728  auto input = op.getInput();
1729  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1730  auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1731 
1732  if (!inputTy || !resultTy)
1733  return rewriter.notifyMatchFailure(op,
1734  "requires ranked input/output types");
1735 
1736  auto batch = inputTy.getDimSize(0);
1737  auto channels = inputTy.getDimSize(3);
1738  auto inputH = inputTy.getDimSize(1);
1739  auto inputW = inputTy.getDimSize(2);
1740  auto outputH = resultTy.getDimSize(1);
1741  auto outputW = resultTy.getDimSize(2);
1742 
1743  if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1744  return rewriter.notifyMatchFailure(
1745  op, "tosa.resize has no broadcasting behavior");
1746 
1747  // For any dimension that is broadcast, the intermediate resize keeps a
1748  // size of 1 on the output.
1749  llvm::SmallVector<int64_t> resizeShape;
1750  resizeShape.push_back(batch);
1751  resizeShape.push_back(inputH == 1 ? 1 : outputH);
1752  resizeShape.push_back(inputW == 1 ? 1 : outputW);
1753  resizeShape.push_back(channels);
1754 
1755  auto resizeTy = resultTy.clone(resizeShape);
1756  auto resize =
1757  tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
1758  op.getOffset(), op.getBorder(), op.getMode());
1759 
1760  // Collapse any unit result dims.
1761  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1762  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1763  reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1764  if (inputH != 1)
1765  reassociationMap.push_back({});
1766  reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1767  if (inputW != 1)
1768  reassociationMap.push_back({});
1769  reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1770 
1771  llvm::SmallVector<int64_t> collapseShape = {batch};
1772  if (inputH != 1)
1773  collapseShape.push_back(outputH);
1774  if (inputW != 1)
1775  collapseShape.push_back(outputW);
1776  collapseShape.push_back(channels);
1777 
1778  auto collapseTy = resultTy.clone(collapseShape);
1779  Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
1780  resize, reassociationMap);
1781 
1782  // Broadcast the collapsed shape to the output result.
1783  llvm::SmallVector<Value> outputDynSize;
1784  if (inputTy.isDynamicDim(0))
1785  outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1786  if (inputTy.isDynamicDim(3))
1787  outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1788 
1789  SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1790  utils::IteratorType::parallel);
1791  Value empty = tensor::EmptyOp::create(
1792  builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1793 
1794  SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1795  if (inputH != 1)
1796  inputExprs.push_back(rewriter.getAffineDimExpr(1));
1797  if (inputW != 1)
1798  inputExprs.push_back(rewriter.getAffineDimExpr(2));
1799  inputExprs.push_back(rewriter.getAffineDimExpr(3));
1800 
1801  auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1802  inputExprs, rewriter.getContext());
1803 
1804  auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1805  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1806  op, resultTy, ValueRange{collapse}, ValueRange{empty},
1807  ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1808  [=](OpBuilder &b, Location loc, ValueRange args) {
1809  Value value = args[0];
1810  linalg::YieldOp::create(b, loc, value);
1811  });
1812 
1813  return success();
1814  }
1815 };
1816 
1817 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1818 public:
1819  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1820 
1821  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1822  PatternRewriter &rewriter) const final {
1823  Location loc = op.getLoc();
1824  ImplicitLocOpBuilder b(loc, rewriter);
1825  auto input = op.getInput();
1826  auto inputTy = cast<ShapedType>(input.getType());
1827  auto resultTy = cast<ShapedType>(op.getType());
1828  auto resultETy = resultTy.getElementType();
1829 
1830  bool floatingPointMode = isa<FloatType>(resultETy);
1831  auto floatTy = resultETy;
1832 
1833  auto imageH = inputTy.getShape()[1];
1834  auto imageW = inputTy.getShape()[2];
1835 
1836  auto dynamicDimsOr =
1837  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1838  if (!dynamicDimsOr.has_value())
1839  return rewriter.notifyMatchFailure(
1840  op, "unable to get dynamic dimensions of tosa.resize");
1841 
1842  if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1843  op.getMode() != ResizeMode::BILINEAR)
1844  return rewriter.notifyMatchFailure(
1845  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1846 
1847  SmallVector<AffineMap, 2> affineMaps = {
1848  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1849  auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
1850  resultETy, *dynamicDimsOr);
1851  auto genericOp = linalg::GenericOp::create(
1852  b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1853  getNParallelLoopsAttrs(resultTy.getRank()));
1854  Value resize = genericOp.getResult(0);
1855 
1856  {
1857  OpBuilder::InsertionGuard regionGuard(b);
1858  b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1859  TypeRange({resultETy}), loc);
1860  Value batch = linalg::IndexOp::create(b, 0);
1861  Value y = linalg::IndexOp::create(b, 1);
1862  Value x = linalg::IndexOp::create(b, 2);
1863  Value channel = linalg::IndexOp::create(b, 3);
1864 
1865  Value zeroI32 =
1866  arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
1867  Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
1868  Value hMax =
1869  arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
1870  Value wMax =
1871  arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));
1872 
1873  Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
1874  Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);
1875 
1876  SmallVector<int64_t> scale, offset, border;
1877  if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
1878  !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
1879  !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
1880  return rewriter.notifyMatchFailure(
1881  op, "tosa.resize scale/offset/border should have compile time "
1882  "constant values.");
1883  }
1884 
1885  Value yScaleN, yScaleD, xScaleN, xScaleD;
1886  yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
1887  yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
1888  xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
1889  xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));
1890 
1891  Value yOffset, xOffset, yBorder, xBorder;
1892  yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
1893  xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
1894  yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
1895  xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
1896 
1897  // Compute the ix and dx values for both the X and Y dimensions.
1898  auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1899  Value scaleN, Value scaleD, Value offset,
1900  int size, ImplicitLocOpBuilder &b) {
1901  if (size == 1) {
1902  index = zeroI32;
1903  delta = zeroFp;
1904  return;
1905  }
1906  // x = x * scale_d + offset;
1907  // ix = floor(x / scale_n)
1908  Value val = arith::MulIOp::create(b, in, scaleD);
1909  val = arith::AddIOp::create(b, val, offset);
1910  index = arith::FloorDivSIOp::create(b, val, scaleN);
1911 
1912  // rx = x % scale_n
1913  // dx = rx / scale_n
1914  Value r = arith::RemSIOp::create(b, val, scaleN);
1915  Value rFp = arith::SIToFPOp::create(b, floatTy, r);
1916  Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
1917  delta = arith::DivFOp::create(b, rFp, scaleNfp);
1918  };
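  // For example, with scale_n = 4, scale_d = 2, offset = 0 and in = 3:
  // x = 3 * 2 + 0 = 6, ix = floor(6 / 4) = 1 and dx = (6 % 4) / 4 = 0.5.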
1919 
1920  // Compute the ix and dx values for the X and Y dimensions - int case.
1921  auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1922  Value scaleN, Value scaleD, Value offset,
1923  int size, ImplicitLocOpBuilder &b) {
1924  if (size == 1) {
1925  index = zeroI32;
1926  delta = zeroI32;
1927  return;
1928  }
1929  // x = x * scale_d + offset;
1930  // ix = floor(x / scale_n)
1931  // dx = x - ix * scale_n;
1932  Value val = arith::MulIOp::create(b, in, scaleD);
1933  val = arith::AddIOp::create(b, val, offset);
1934  index = arith::DivSIOp::create(b, val, scaleN);
1935  delta = arith::MulIOp::create(b, index, scaleN);
1936  delta = arith::SubIOp::create(b, val, delta);
1937  };
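  // Same example in the integer case: x = 6, ix = 6 / 4 = 1 and
  // dx = 6 - 1 * 4 = 2 (a fixed-point remainder rather than a fraction).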
1938 
1939  Value ix, iy, dx, dy;
1940  if (floatingPointMode) {
1941  getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1942  getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1943  } else {
1944  getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1945  getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1946  }
1947 
1948  if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
1949  auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
1950 
1951  auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1952  Value max, int size,
1953  ImplicitLocOpBuilder &b) -> Value {
1954  if (size == 1) {
1955  return arith::ConstantIndexOp::create(b, 0);
1956  }
1957 
1958  Value pred;
1959  if (floatingPointMode) {
1960  auto h =
1961  arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
1962  pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
1963  } else {
1964  Value dvalDouble = arith::ShLIOp::create(b, dval, one);
1965  pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
1966  dvalDouble, scale);
1967  }
1968 
1969  auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
1970  val = arith::AddIOp::create(b, val, offset);
1971  val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
1972  return arith::IndexCastOp::create(b, b.getIndexType(), val);
1973  };
1974 
1975  iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1976  ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1977 
1978  Value result = tensor::ExtractOp::create(
1979  b, input, ValueRange{batch, iy, ix, channel});
1980 
1981  linalg::YieldOp::create(b, result);
1982  } else {
1983  // The mode here must be BILINEAR.
1984  assert(op.getMode() == ResizeMode::BILINEAR);
1985 
1986  auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
1987 
1988  auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1989  Value max, ImplicitLocOpBuilder &b) {
1990  val0 = in;
1991  val1 = arith::AddIOp::create(b, val0, oneVal);
1992  val0 =
1993  clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1994  val1 =
1995  clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
1996  val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
1997  val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
1998  };
1999 
2000  // Linalg equivalent to the section below:
2001  // int16_t iy0 = apply_max(iy, 0);
2002  // int16_t iy1 = apply_min(iy + 1, IH - 1);
2003  // int16_t ix0 = apply_max(ix, 0);
2004  // int16_t ix1 = apply_min(ix + 1, IW - 1);
2005  Value x0, x1, y0, y1;
2006  getClampedIdxs(y0, y1, imageH, iy, hMax, b);
2007  getClampedIdxs(x0, x1, imageW, ix, wMax, b);
2008 
2009  Value y0x0 = tensor::ExtractOp::create(
2010  b, input, ValueRange{batch, y0, x0, channel});
2011  Value y0x1 = tensor::ExtractOp::create(
2012  b, input, ValueRange{batch, y0, x1, channel});
2013  Value y1x0 = tensor::ExtractOp::create(
2014  b, input, ValueRange{batch, y1, x0, channel});
2015  Value y1x1 = tensor::ExtractOp::create(
2016  b, input, ValueRange{batch, y1, x1, channel});
2017 
2018  if (floatingPointMode) {
2019  auto oneVal =
2020  arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
2021  auto interpolate = [&](Value val0, Value val1, Value delta,
2022  int inputSize,
2023  ImplicitLocOpBuilder &b) -> Value {
2024  if (inputSize == 1)
2025  return val0;
2026  Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
2027  Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
2028  Value mul1 = arith::MulFOp::create(b, val1, delta);
2029  return arith::AddFOp::create(b, mul0, mul1);
2030  };
2031 
2032  // Linalg equivalent to the section below:
2033  // topAcc = v00 * (unit_x - dx);
2034  // topAcc += v01 * dx;
2035  Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
2036 
2037  // Linalg equivalent to the section below:
2038  // bottomAcc = v10 * (unit_x - dx);
2039  // bottomAcc += v11 * dx;
2040  Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
2041 
2042  // Linalg equivalent to the section below:
2043  // result = topAcc * (unit_y - dy) + bottomAcc * dy
2044  Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
2045  linalg::YieldOp::create(b, result);
2046  } else {
2047  // Perform in quantized space.
2048  y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
2049  y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
2050  y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
2051  y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);
2052 
2053  const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
2054  if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
2055  dx = arith::ExtSIOp::create(b, resultETy, dx);
2056  dy = arith::ExtSIOp::create(b, resultETy, dy);
2057  }
2058 
2059  Value yScaleNExt = yScaleN;
2060  Value xScaleNExt = xScaleN;
2061 
2062  const int64_t scaleBitwidth =
2063  xScaleN.getType().getIntOrFloatBitWidth();
2064  if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2065  yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
2066  xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
2067  }
2068 
2069  auto interpolate = [](Value val0, Value val1, Value weight1,
2070  Value scale, int inputSize,
2071  ImplicitLocOpBuilder &b) -> Value {
2072  if (inputSize == 1)
2073  return arith::MulIOp::create(b, val0, scale);
2074  Value weight0 = arith::SubIOp::create(b, scale, weight1);
2075  Value mul0 = arith::MulIOp::create(b, val0, weight0);
2076  Value mul1 = arith::MulIOp::create(b, val1, weight1);
2077  return arith::AddIOp::create(b, mul0, mul1);
2078  };
2079 
2080  Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2081  Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2082  Value result =
2083  interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2084  linalg::YieldOp::create(b, result);
2085  }
2086  }
2087  }
2088 
2089  rewriter.replaceOp(op, resize);
2090  return success();
2091  }
2092 };
2093 
2094 // At the codegen level any identity operations should be removed. Any cases
2095 // where identity is load-bearing (e.g. cross device computation) should be
2096 // handled before lowering to codegen.
2097 template <typename SrcOp>
2098 class IdentityNConverter : public OpRewritePattern<SrcOp> {
2099 public:
2100  using OpRewritePattern<SrcOp>::OpRewritePattern;
2101 
2102  LogicalResult matchAndRewrite(SrcOp op,
2103  PatternRewriter &rewriter) const final {
2104  rewriter.replaceOp(op, op.getOperation()->getOperands());
2105  return success();
2106  }
2107 };
2108 
2109 template <typename SrcOp>
2110 class ReduceConverter : public OpRewritePattern<SrcOp> {
2111 public:
2112  using OpRewritePattern<SrcOp>::OpRewritePattern;
2113 
2114  LogicalResult matchAndRewrite(SrcOp reduceOp,
2115  PatternRewriter &rewriter) const final {
2116  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2117  }
2118 };
2119 
2120 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
2121 public:
2122  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
2123 
2124  LogicalResult matchAndRewrite(tosa::ReverseOp op,
2125  PatternRewriter &rewriter) const final {
2126  auto loc = op.getLoc();
2127  Value input = op.getInput1();
2128  auto inputTy = cast<ShapedType>(input.getType());
2129  auto resultTy = cast<ShapedType>(op.getType());
2130  auto axis = op.getAxis();
2131 
2132  SmallVector<Value> dynDims;
2133  for (int i = 0; i < inputTy.getRank(); i++) {
2134  if (inputTy.isDynamicDim(i)) {
2135  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2136  }
2137  }
2138 
2139  Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
2140 
2141  // Create the empty output buffer that the reversed values are written into.
2142  auto emptyTensor = tensor::EmptyOp::create(
2143  rewriter, loc, inputTy.getShape(),
2144  inputTy.getElementType(), ArrayRef<Value>({dynDims}))
2145  .getResult();
2146  SmallVector<AffineMap, 2> affineMaps = {
2147  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2148 
2149  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2150  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
2151  getNParallelLoopsAttrs(resultTy.getRank()),
2152  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2153  llvm::SmallVector<Value> indices;
2154  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
2155  Value index =
2156  linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
2157  if (i == axis) {
2158  auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
2159  auto sizeMinusOne =
2160  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
2161  index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
2162  index);
2163  }
2164 
2165  indices.push_back(index);
2166  }
2167 
2168  auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
2169  input, indices);
2170  linalg::YieldOp::create(nestedBuilder, op.getLoc(),
2171  extract.getResult());
2172  });
2173  return success();
2174  }
2175 };
2176 
2177 // This converter translates a tile operation into a reshape, broadcast, and
2178 // reshape. The first reshape minimally expands each tiled dimension to
2179 // include a preceding size-1 dim. This dim is then broadcast to the
2180 // appropriate multiple.
2181 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
2182  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
2183 
2184  LogicalResult
2185  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2186  ConversionPatternRewriter &rewriter) const override {
2187  auto loc = op.getLoc();
2188  auto input = op.getInput1();
2189  auto inputTy = cast<ShapedType>(input.getType());
2190  auto inputShape = inputTy.getShape();
2191  auto resultTy = cast<ShapedType>(op.getType());
2192  auto elementTy = inputTy.getElementType();
2193  int64_t rank = inputTy.getRank();
2194 
2195  SmallVector<int64_t> multiples;
2196  if (failed(op.getConstantMultiples(multiples)))
2197  return failure();
2198 
2199  // Broadcast the newly added dimensions to their appropriate multiple.
2200  SmallVector<int64_t, 2> genericShape;
2201  for (int i = 0; i < rank; i++) {
2202  int64_t dim = multiples[i];
2203  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
2204  genericShape.push_back(inputShape[i]);
2205  }
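  // For example, tiling a [2, 3] input with multiples [2, 1] produces the
  // intermediate genericShape [2, 2, 1, 3], which later reshapes to [4, 3].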
2206 
2207  SmallVector<Value> dynDims;
2208  for (int i = 0; i < inputTy.getRank(); i++) {
2209  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
2210  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2211  }
2212  }
2213 
2214  auto emptyTensor = tensor::EmptyOp::create(
2215  rewriter, op.getLoc(), genericShape, elementTy, dynDims);
2216 
2217  // We need to map the input shape to the non-broadcast dimensions.
2218  SmallVector<AffineExpr, 4> dimExprs;
2219  dimExprs.reserve(rank);
2220  for (unsigned i = 0; i < rank; ++i)
2221  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
2222 
2223  auto readAffineMap =
2224  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
2225  rewriter.getContext());
2226 
2227  SmallVector<AffineMap, 2> affineMaps = {
2228  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
2229 
2230  auto genericOp = linalg::GenericOp::create(
2231  rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
2232  ValueRange{emptyTensor}, affineMaps,
2233  getNParallelLoopsAttrs(genericShape.size()),
2234  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2235  linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
2236  });
2237 
2238  auto shapeValue = getTosaConstShape(
2239  rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
2240  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2241  op, resultTy, genericOp.getResult(0), shapeValue);
2242  return success();
2243  }
2244 };
2245 
2246 // Tosa argmax lowering represents the ArgMax op as a linalg.generic
2247 // op, producing two output buffers.
2248 //
2249 // The first output buffer contains the index of the found maximum value. It is
2250 // initialized to 0 and has the resulting integer type.
2251 //
2252 // The second output buffer contains the maximum value found. It is initialized
2253 // to the minimum representable value of the input element type. After being
2254 // populated by the generic op, this buffer is discarded as only the index is
2255 // requested.
2256 //
2257 // The generic op updates both the maximum value and index if the
2258 // current value exceeds the running max.
2259 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2260 public:
2261  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2262 
2263  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2264  PatternRewriter &rewriter) const final {
2265  auto loc = argmaxOp.getLoc();
2266  Value input = argmaxOp.getInput();
2267  auto inputTy = cast<ShapedType>(input.getType());
2268  auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2269  auto inElementTy = inputTy.getElementType();
2270  auto outElementTy = resultTy.getElementType();
2271  int axis = argmaxOp.getAxis();
2272  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2273 
2274  if (!isa<IntegerType>(outElementTy))
2275  return rewriter.notifyMatchFailure(
2276  argmaxOp,
2277  "tosa.arg_max to linalg.* requires integer-like result type");
2278 
2279  SmallVector<Value> dynDims;
2280  for (int i = 0; i < inputTy.getRank(); i++) {
2281  if (inputTy.isDynamicDim(i) && i != axis) {
2282  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2283  }
2284  }
2285 
2286  // First fill the output buffer for the index.
2287  auto emptyTensorIdx =
2288  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2289  outElementTy, dynDims)
2290  .getResult();
2291  auto fillValueIdx = arith::ConstantOp::create(
2292  rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
2293  auto filledTensorIdx =
2294  linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
2295  ValueRange{emptyTensorIdx})
2296  .result();
2297 
2298  // Second fill the output buffer for the running max.
2299  auto emptyTensorMax =
2300  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
2301  dynDims)
2302  .getResult();
2303  auto fillValueMaxAttr =
2304  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2305 
2306  if (!fillValueMaxAttr)
2307  return rewriter.notifyMatchFailure(
2308  argmaxOp, "unsupported tosa.argmax element type");
2309 
2310  auto fillValueMax =
2311  arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
2312  auto filledTensorMax =
2313  linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
2314  ValueRange{emptyTensorMax})
2315  .result();
2316 
2317  // We need to reduce along the arg-max axis, with parallel operations along
2318  // the rest.
2319  SmallVector<utils::IteratorType, 4> iteratorTypes;
2320  iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2321  iteratorTypes[axis] = utils::IteratorType::reduction;
2322 
2323  SmallVector<AffineExpr, 2> srcExprs;
2324  SmallVector<AffineExpr, 2> dstExprs;
2325  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2326  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2327  if (axis != i)
2328  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2329  }
2330 
2331  bool didEncounterError = false;
2332  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2333  rewriter.getContext());
2334  auto linalgOp = linalg::GenericOp::create(
2335  rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2336  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2337  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2338  ValueRange blockArgs) {
2339  auto newValue = blockArgs[0];
2340  auto oldIndex = blockArgs[1];
2341  auto oldValue = blockArgs[2];
2342 
2343  Value newIndex = arith::IndexCastOp::create(
2344  rewriter, nestedLoc, oldIndex.getType(),
2345  linalg::IndexOp::create(rewriter, loc, axis));
2346 
2347  Value predicate;
2348  if (isa<FloatType>(inElementTy)) {
2349  if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
2350  // Only update the index & max value for non-NaN values. If all
2351  // values are NaN, the initial index, which is 0, is returned.
2352  predicate = arith::CmpFOp::create(rewriter, nestedLoc,
2353  arith::CmpFPredicate::OGT,
2354  newValue, oldValue);
2355  } else {
2356  // Update max value if either of the following is true:
2357  // - new value is bigger
2358  // - cur max is not NaN and new value is NaN
2359  Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
2360  arith::CmpFPredicate::UGT,
2361  newValue, oldValue);
2362  Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
2363  arith::CmpFPredicate::ORD,
2364  oldValue, oldValue);
2365  predicate = arith::AndIOp::create(
2366  rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2367  }
2368  } else if (isa<IntegerType>(inElementTy)) {
2369  predicate = arith::CmpIOp::create(rewriter, nestedLoc,
2370  arith::CmpIPredicate::sgt,
2371  newValue, oldValue);
2372  } else {
2373  didEncounterError = true;
2374  return;
2375  }
2376 
2377  auto resultMax = arith::SelectOp::create(
2378  rewriter, nestedLoc, predicate, newValue, oldValue);
2379  auto resultIndex = arith::SelectOp::create(
2380  rewriter, nestedLoc, predicate, newIndex, oldIndex);
2381  linalg::YieldOp::create(nestedBuilder, nestedLoc,
2382  ValueRange({resultIndex, resultMax}));
2383  });
2384 
2385  if (didEncounterError)
2386  return rewriter.notifyMatchFailure(
2387  argmaxOp, "unsupported tosa.argmax element type");
2388 
2389  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2390  return success();
2391  }
2392 };
2393 
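// Lowers tosa.gather to a linalg.generic whose body extracts
// values[batch, indices[batch, w], channel] for every output element
// [batch, w, channel].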
2394 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2395 public:
2396  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2397  LogicalResult
2398  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2399  ConversionPatternRewriter &rewriter) const final {
2400  auto input = adaptor.getOperands()[0];
2401  auto indices = adaptor.getOperands()[1];
2402 
2403  auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2404  auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2405  if (!valuesTy || !resultTy)
2406  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2407 
2408  auto dynamicDims = inferDynamicDimsForGather(
2409  rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2410 
2411  auto resultElementTy = resultTy.getElementType();
2412 
2413  auto loc = op.getLoc();
2414  auto emptyTensor =
2415  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2416  resultElementTy, dynamicDims)
2417  .getResult();
2418 
2419  SmallVector<AffineMap, 2> affineMaps = {
2420  AffineMap::get(
2421  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2422  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2423  rewriter.getContext()),
2424  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2425 
2426  auto genericOp = linalg::GenericOp::create(
2427  rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2428  ValueRange{emptyTensor}, affineMaps,
2429  getNParallelLoopsAttrs(resultTy.getRank()),
2430  [&](OpBuilder &b, Location loc, ValueRange args) {
2431  auto indexValue = args[0];
2432  auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
2433  Value index1 = arith::IndexCastOp::create(
2434  rewriter, loc, rewriter.getIndexType(), indexValue);
2435  auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
2436  Value extract = tensor::ExtractOp::create(
2437  rewriter, loc, input, ValueRange{index0, index1, index2});
2438  linalg::YieldOp::create(rewriter, loc, extract);
2439  });
2440  rewriter.replaceOp(op, genericOp.getResult(0));
2441  return success();
2442  }
2443 
2444  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2445  Location loc,
2446  Value values,
2447  Value indices) {
2448  llvm::SmallVector<Value> results;
2449 
2450  auto addDynamicDimension = [&](Value source, int64_t dim) {
2451  auto sz = tensor::getMixedSize(builder, loc, source, dim);
2452  if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2453  results.push_back(dimValue);
2454  };
2455 
2456  addDynamicDimension(values, 0);
2457  addDynamicDimension(indices, 1);
2458  addDynamicDimension(values, 2);
2459  return results;
2460  }
2461 };
2462 
2463 // Lowers the TableOp to a series of gathers and numeric operations. This
2464 // includes interpolation between the high/low values. For the I8 variant, this
2465 // simplifies to a single gather operation.
2466 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2467 public:
2468  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2469 
2470  LogicalResult matchAndRewrite(tosa::TableOp op,
2471  PatternRewriter &rewriter) const final {
2472  auto loc = op.getLoc();
2473  Value input = op.getInput1();
2474  Value table = op.getTable();
2475  auto inputTy = cast<ShapedType>(input.getType());
2476  auto tableTy = cast<ShapedType>(table.getType());
2477  auto resultTy = cast<ShapedType>(op.getType());
2478 
2479  auto inputElementTy = inputTy.getElementType();
2480  auto tableElementTy = tableTy.getElementType();
2481  auto resultElementTy = resultTy.getElementType();
2482 
2483  SmallVector<Value> dynDims;
2484  for (int i = 0; i < resultTy.getRank(); ++i) {
2485  if (inputTy.isDynamicDim(i)) {
2486  dynDims.push_back(
2487  tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
2488  }
2489  }
2490 
2491  auto emptyTensor =
2492  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2493  resultElementTy, dynDims)
2494  .getResult();
2495 
2496  SmallVector<AffineMap, 2> affineMaps = {
2497  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2498  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2499 
2500  auto genericOp = linalg::GenericOp::create(
2501  rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
2502  affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
2503  rewriter.replaceOp(op, genericOp.getResult(0));
2504 
2505  {
2506  OpBuilder::InsertionGuard regionGuard(rewriter);
2507  Block *block = rewriter.createBlock(
2508  &genericOp.getRegion(), genericOp.getRegion().end(),
2509  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2510 
2511  auto inputValue = block->getArgument(0);
2512  rewriter.setInsertionPointToStart(block);
2513  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2514  resultElementTy.isInteger(8)) {
2515  Value index = arith::IndexCastOp::create(
2516  rewriter, loc, rewriter.getIndexType(), inputValue);
2517  Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
2518  index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
2519  index, offset);
2520  Value extract =
2521  tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2522  linalg::YieldOp::create(rewriter, loc, extract);
2523  return success();
2524  }
2525 
2526  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2527  resultElementTy.isInteger(32)) {
2528  Value extend = arith::ExtSIOp::create(
2529  rewriter, loc, rewriter.getI32Type(), inputValue);
2530 
2531  auto offset = arith::ConstantOp::create(
2532  rewriter, loc, rewriter.getI32IntegerAttr(32768));
2533  auto seven = arith::ConstantOp::create(rewriter, loc,
2534  rewriter.getI32IntegerAttr(7));
2535  auto one = arith::ConstantOp::create(rewriter, loc,
2536  rewriter.getI32IntegerAttr(1));
2537  auto b1111111 = arith::ConstantOp::create(
2538  rewriter, loc, rewriter.getI32IntegerAttr(127));
2539 
2540  // Compute the index and fractional part from the input value:
2541  // value = value + 32768
2542  // index = value >> 7;
2543  // fraction = 0b01111111 & value
2544  auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
2545  Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
2546  Value fraction =
2547  arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
2548 
2549  // Extract the base and next values from the table.
2550  // base = (int32_t) table[index];
2551  // next = (int32_t) table[index + 1];
2552  Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
2553 
2554  index = arith::IndexCastOp::create(rewriter, loc,
2555  rewriter.getIndexType(), index);
2556  indexPlusOne = arith::IndexCastOp::create(
2557  rewriter, loc, rewriter.getIndexType(), indexPlusOne);
2558 
2559  Value base =
2560  tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2561  Value next = tensor::ExtractOp::create(rewriter, loc, table,
2562  ValueRange{indexPlusOne});
2563 
2564  base =
2565  arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
2566  next =
2567  arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
2568 
2569  // Use the fractional part to interpolate between the input values:
2570  // result = (base << 7) + (next - base) * fraction
2571  Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
2572  Value diff = arith::SubIOp::create(rewriter, loc, next, base);
2573  Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
2574  Value result =
2575  arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
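  // Worked example (illustrative): an input of 0x1234 becomes 0x9234 after
  // the offset, giving index = 0x124 and fraction = 0x34, so the result
  // blends table[0x124] and table[0x125] with a weight of 0x34 out of 128.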
2576 
2577  linalg::YieldOp::create(rewriter, loc, result);
2578 
2579  return success();
2580  }
2581  }
2582 
2583  return rewriter.notifyMatchFailure(
2584  op, "unable to create body for tosa.table op");
2585  }
2586 };
2587 
2588 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2589  using OpRewritePattern::OpRewritePattern;
2590 
2591  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2592 
2593  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2594  OpFoldResult ofr) {
2595  auto one = arith::ConstantIndexOp::create(builder, loc, 1);
2596  auto two = arith::ConstantIndexOp::create(builder, loc, 2);
2597 
2598  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2599  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2600  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2601  return getAsOpFoldResult(plusOne);
2602  }
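  // For example, an input width of 8 yields an output width of 8 / 2 + 1 = 5,
  // matching the one-sided spectrum size of a real FFT.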
2603 
2604  static RankedTensorType
2605  computeOutputShape(OpBuilder &builder, Location loc, Value input,
2606  llvm::SmallVectorImpl<Value> &dynamicSizes) {
2607  // Get [N, H, W]
2608  auto dims = tensor::getMixedSizes(builder, loc, input);
2609 
2610  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2611  // output tensors.
2612  dims[2] = halfPlusOne(builder, loc, dims[2]);
2613 
2614  llvm::SmallVector<int64_t, 3> staticSizes;
2615  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2616 
2617  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2618  return RankedTensorType::get(staticSizes, elementType);
2619  }
2620 
2621  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2622  RankedTensorType type,
2623  llvm::ArrayRef<Value> dynamicSizes) {
2624  auto emptyTensor =
2625  tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
2626  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2627  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
2628  auto filledTensor =
2629  linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
2630  ValueRange{emptyTensor})
2631  .result();
2632  return filledTensor;
2633  }
2634 
2635  static Value castIndexToFloat(OpBuilder &builder, Location loc,
2636  FloatType type, Value value) {
2637  auto integerVal = arith::IndexCastUIOp::create(
2638  builder, loc,
2639  type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2640  : builder.getI32Type(),
2641  value);
2642 
2643  return arith::UIToFPOp::create(builder, loc, type, integerVal);
2644  }
2645 
2646  static Value createLinalgIndex(OpBuilder &builder, Location loc,
2647  FloatType type, int64_t index) {
2648  auto indexVal = linalg::IndexOp::create(builder, loc, index);
2649  return castIndexToFloat(builder, loc, type, indexVal);
2650  }
2651 
2652  template <typename... Args>
2653  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2654  Args... args) {
2655  return {builder.getAffineDimExpr(args)...};
2656  }
2657 
2658  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2659  PatternRewriter &rewriter) const override {
2660  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2661  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2662  return rewriter.notifyMatchFailure(rfft2d,
2663  "only supports ranked tensors");
2664  }
2665 
2666  auto loc = rfft2d.getLoc();
2667  auto input = rfft2d.getInputReal();
2668  auto elementType =
2669  dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2670  if (!elementType)
2671  return rewriter.notifyMatchFailure(rfft2d,
2672  "only supports float element types");
2673 
2674  // Compute the output type and set of dynamic sizes
2675  llvm::SmallVector<Value> dynamicSizes;
2676  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2677 
2678  // Iterator types for the linalg.generic implementation
2679  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2680  utils::IteratorType::parallel, utils::IteratorType::parallel,
2681  utils::IteratorType::parallel, utils::IteratorType::reduction,
2682  utils::IteratorType::reduction};
2683 
2684  // Inputs/outputs to the linalg.generic implementation
2685  llvm::SmallVector<Value> genericOpInputs = {input};
2686  llvm::SmallVector<Value> genericOpOutputs = {
2687  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2688  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2689 
2690  // Indexing maps for input and output tensors
2691  auto indexingMaps = AffineMap::inferFromExprList(
2692  llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2693  affineDimsExpr(rewriter, 0, 1, 2),
2694  affineDimsExpr(rewriter, 0, 1, 2)},
2695  rewriter.getContext());
2696 
2697  // Width and height dimensions of the original input.
2698  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2699  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2700 
2701  // Constants and dimension sizes
2702  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2703  auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2704  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2705  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2706 
2707  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2708  Value valReal = args[0];
2709  Value sumReal = args[1];
2710  Value sumImag = args[2];
2711 
2712  // Indices for angle computation
2713  Value oy = linalg::IndexOp::create(builder, loc, 1);
2714  Value ox = linalg::IndexOp::create(builder, loc, 2);
2715  Value iy = linalg::IndexOp::create(builder, loc, 3);
2716  Value ix = linalg::IndexOp::create(builder, loc, 4);
2717 
2718  // Calculating angle without integer parts of components as sin/cos are
2719  // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2720  // / W);
2721  auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2722  auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2723 
2724  auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2725  auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2726 
2727  auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2728  auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2729 
2730  auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2731  auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2732  auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2733  auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2734 
2735  // realComponent = valReal * cos(angle)
2736  // imagComponent = valReal * sin(angle)
2737  auto cosAngle = math::CosOp::create(builder, loc, angle);
2738  auto sinAngle = math::SinOp::create(builder, loc, angle);
2739  auto realComponent =
2740  arith::MulFOp::create(builder, loc, valReal, cosAngle);
2741  auto imagComponent =
2742  arith::MulFOp::create(builder, loc, valReal, sinAngle);
2743 
2744  // outReal = sumReal + realComponent
2745  // outImag = sumImag - imagComponent
2746  auto outReal =
2747  arith::AddFOp::create(builder, loc, sumReal, realComponent);
2748  auto outImag =
2749  arith::SubFOp::create(builder, loc, sumImag, imagComponent);
2750 
2751  linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
2752  };
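  // Illustrative check of the accumulation: for the DC output element
  // (oy = ox = 0) every angle is 0, so cos = 1 and sin = 0; the real
  // accumulator reduces to the plain sum of the inputs and the imaginary
  // accumulator stays 0, as expected for a real-input DFT.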
2753 
2754  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2755  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2756  indexingMaps, iteratorTypes, buildBody);
2757 
2758  return success();
2759  }
2760 };
2761 
2762 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2763  using OpRewritePattern::OpRewritePattern;
2764 
2765  LogicalResult matchAndRewrite(FFT2dOp fft2d,
2766  PatternRewriter &rewriter) const override {
2767  if (!llvm::all_of(fft2d->getOperandTypes(),
2768  RFFT2dConverter::isRankedTensor) ||
2769  !llvm::all_of(fft2d->getResultTypes(),
2770  RFFT2dConverter::isRankedTensor)) {
2771  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2772  }
2773 
2774  Location loc = fft2d.getLoc();
2775  Value input_real = fft2d.getInputReal();
2776  Value input_imag = fft2d.getInputImag();
2777  BoolAttr inverse = fft2d.getInverseAttr();
2778 
2779  auto real_el_ty = cast<FloatType>(
2780  cast<ShapedType>(input_real.getType()).getElementType());
2781  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2782  cast<ShapedType>(input_imag.getType()).getElementType());
2783 
2784  assert(real_el_ty == imag_el_ty);
2785 
2786  // Compute the output type and set of dynamic sizes
2787  SmallVector<Value> dynamicSizes;
2788 
2789  // Get [N, H, W]
2790  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2791 
2792  SmallVector<int64_t, 3> staticSizes;
2793  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2794 
2795  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2796 
2797  // Iterator types for the linalg.generic implementation
2798  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2799  utils::IteratorType::parallel, utils::IteratorType::parallel,
2800  utils::IteratorType::parallel, utils::IteratorType::reduction,
2801  utils::IteratorType::reduction};
2802 
2803  // Inputs/outputs to the linalg.generic implementation
2804  SmallVector<Value> genericOpInputs = {input_real, input_imag};
2805  SmallVector<Value> genericOpOutputs = {
2806  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2807  dynamicSizes),
2808  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2809  dynamicSizes)};
2810 
2811  // Indexing maps for input and output tensors
2812  auto indexingMaps = AffineMap::inferFromExprList(
2813  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2814  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2815  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2816  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2817  rewriter.getContext());
2818 
2819  // Width and height dimensions of the original input.
2820  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2821  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2822 
2823  // Constants and dimension sizes
2824  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2825  auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2826  Value constH =
2827  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2828  Value constW =
2829  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2830 
2831  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2832  Value valReal = args[0];
2833  Value valImag = args[1];
2834  Value sumReal = args[2];
2835  Value sumImag = args[3];
2836 
2837  // Indices for angle computation
2838  Value oy = linalg::IndexOp::create(builder, loc, 1);
2839  Value ox = linalg::IndexOp::create(builder, loc, 2);
2840  Value iy = linalg::IndexOp::create(builder, loc, 3);
2841  Value ix = linalg::IndexOp::create(builder, loc, 4);
2842 
2843  // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2844  // ox) % W ) / W);
2845  auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2846  auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2847 
2848  auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2849  auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2850 
2851  auto iyRemFloat =
2852  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2853  auto ixRemFloat =
2854  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2855 
2856  auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2857  auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2858 
2859  auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2860  auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2861 
2862  if (inverse.getValue()) {
2863  angle = arith::MulFOp::create(
2864  builder, loc, angle,
2865  arith::ConstantOp::create(rewriter, loc,
2866  rewriter.getFloatAttr(real_el_ty, -1.0)));
2867  }
2868 
2869  // realComponent = val_real * cos(a) + val_imag * sin(a);
2870  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2871  auto cosAngle = math::CosOp::create(builder, loc, angle);
2872  auto sinAngle = math::SinOp::create(builder, loc, angle);
2873 
2874  auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
2875  auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
2876  auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
2877 
2878  auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
2879  auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
2880 
2881  auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
2882 
2883  // outReal = sumReal + realComponent
2884  // outImag = sumImag + imagComponent (the sign is already folded into imagComponent above)
2885  auto outReal =
2886  arith::AddFOp::create(builder, loc, sumReal, realComponent);
2887  auto outImag =
2888  arith::AddFOp::create(builder, loc, sumImag, imagComponent);
2889 
2890  linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
2891  };
2892 
2893  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2894  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2895  indexingMaps, iteratorTypes, buildBody);
2896 
2897  return success();
2898  }
2899 };
2900 
2901 } // namespace
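For reference, the linalg.generic body built above accumulates a plain 2-D DFT; restating the in-code comments as equations (H and W are the input height and width, and the angle is negated when the inverse attribute is set):

\begin{aligned}
a &= 2\pi\left(\frac{(i_y\,o_y) \bmod H}{H} + \frac{(i_x\,o_x) \bmod W}{W}\right)\\
\mathrm{out\_real}[n, o_y, o_x] &= \sum_{i_y=0}^{H-1}\sum_{i_x=0}^{W-1}
  \bigl(\mathrm{in\_real}[n, i_y, i_x]\cos a + \mathrm{in\_imag}[n, i_y, i_x]\sin a\bigr)\\
\mathrm{out\_imag}[n, o_y, o_x] &= \sum_{i_y=0}^{H-1}\sum_{i_x=0}^{W-1}
  \bigl(\mathrm{in\_imag}[n, i_y, i_x]\cos a - \mathrm{in\_real}[n, i_y, i_x]\sin a\bigr)
\end{aligned}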
2902 
2903 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2904  const TypeConverter &converter, RewritePatternSet *patterns) {
2905 
2906  // We have multiple resize converters to handle degenerate cases; patterns with higher benefit are tried first.
2907  patterns->add<GenericResizeConverter>(patterns->getContext(),
2908  /*benefit=*/100);
2909  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2910  /*benefit=*/200);
2911  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2912  /*benefit=*/300);
2913 
2914  patterns->add<
2915  // clang-format off
2916  PointwiseConverter<tosa::AddOp>,
2917  PointwiseConverter<tosa::SubOp>,
2918  PointwiseConverter<tosa::MulOp>,
2919  PointwiseConverter<tosa::IntDivOp>,
2920  PointwiseConverter<tosa::NegateOp>,
2921  PointwiseConverter<tosa::PowOp>,
2922  PointwiseConverter<tosa::ReciprocalOp>,
2923  PointwiseConverter<tosa::RsqrtOp>,
2924  PointwiseConverter<tosa::LogOp>,
2925  PointwiseConverter<tosa::ExpOp>,
2926  PointwiseConverter<tosa::AbsOp>,
2927  PointwiseConverter<tosa::SinOp>,
2928  PointwiseConverter<tosa::CosOp>,
2929  PointwiseConverter<tosa::TanhOp>,
2930  PointwiseConverter<tosa::ErfOp>,
2931  PointwiseConverter<tosa::BitwiseAndOp>,
2932  PointwiseConverter<tosa::BitwiseOrOp>,
2933  PointwiseConverter<tosa::BitwiseNotOp>,
2934  PointwiseConverter<tosa::BitwiseXorOp>,
2935  PointwiseConverter<tosa::LogicalAndOp>,
2936  PointwiseConverter<tosa::LogicalNotOp>,
2937  PointwiseConverter<tosa::LogicalOrOp>,
2938  PointwiseConverter<tosa::LogicalXorOp>,
2939  PointwiseConverter<tosa::CastOp>,
2940  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2941  PointwiseConverter<tosa::LogicalRightShiftOp>,
2942  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2943  PointwiseConverter<tosa::ClzOp>,
2944  PointwiseConverter<tosa::SelectOp>,
2945  PointwiseConverter<tosa::GreaterOp>,
2946  PointwiseConverter<tosa::GreaterEqualOp>,
2947  PointwiseConverter<tosa::EqualOp>,
2948  PointwiseConverter<tosa::MaximumOp>,
2949  PointwiseConverter<tosa::MinimumOp>,
2950  PointwiseConverter<tosa::CeilOp>,
2951  PointwiseConverter<tosa::FloorOp>,
2952  PointwiseConverter<tosa::ClampOp>,
2953  PointwiseConverter<tosa::SigmoidOp>
2954  >(converter, patterns->getContext());
2955 
2956  patterns->add<
2957  IdentityNConverter<tosa::IdentityOp>,
2958  ReduceConverter<tosa::ReduceAllOp>,
2959  ReduceConverter<tosa::ReduceAnyOp>,
2960  ReduceConverter<tosa::ReduceMinOp>,
2961  ReduceConverter<tosa::ReduceMaxOp>,
2962  ReduceConverter<tosa::ReduceSumOp>,
2963  ReduceConverter<tosa::ReduceProductOp>,
2964  ArgMaxConverter,
2965  GatherConverter,
2966  RescaleConverter,
2967  ReverseConverter,
2968  RFFT2dConverter,
2969  FFT2dConverter,
2970  TableConverter,
2971  TileConverter>(patterns->getContext());
2972  // clang-format on
2973 }
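As a usage note (not part of this file): below is a minimal sketch of how a conversion pass might drive the patterns populated above. The function name, the identity type conversion, and the exact legal/illegal dialect choices are illustrative assumptions; only populateTosaToLinalgConversionPatterns and the standard dialect-conversion APIs are taken as given.

// Sketch only: run the TOSA->Linalg patterns through a full dialect
// conversion. Names and legality choices below are assumptions for
// illustration.
static LogicalResult runTosaToLinalgSketch(Operation *root) {
  MLIRContext *ctx = root->getContext();

  // An identity type conversion is sufficient for this sketch.
  TypeConverter converter;
  converter.addConversion([](Type type) { return type; });

  RewritePatternSet patterns(ctx);
  mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);

  ConversionTarget target(*ctx);
  target.addIllegalDialect<tosa::TosaDialect>();
  target.addLegalDialect<linalg::LinalgDialect, arith::ArithDialect,
                         math::MathDialect, index::IndexDialect,
                         tensor::TensorDialect>();

  return applyFullConversion(root, target, std::move(patterns));
}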