1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/Index/IR/IndexOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
23 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
24 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/OpDefinition.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/Sequence.h"
31 
32 #include <type_traits>
33 
34 using namespace mlir;
35 using namespace mlir::tosa;
36 
37 // Helper function to materialize the semantically correct compare and select
38 // operations given a binary operation with a specific NaN propagation mode.
39 //
40 // In the case of "PROPAGATE" semantics no compare and selection is required and
41 // this function does nothing.
42 //
43 // In the case of "IGNORE" semantics this function materializes a comparison of
44 // the current operands to the op which will return true for any NaN
45 // argument and then selects between the non-NaN operation argument and the
46 // calculated result based on whether the lhs or rhs is NaN.
47 //
48 // In the case that the op is operating on non-floating-point types we ignore
49 // the attribute completely; this is consistent with the TOSA spec, which has
50 // the following wording: "This attribute is ignored by non floating-point
51 // types."
52 //
53 // In pseudo-code:
54 // binary<op>(lhs, rhs):
55 // result = op(lhs, rhs)
56 // if lhs == NaN return rhs
57 // if rhs == NaN return lhs
58 // return result
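//
// For example (illustrative): maximum(NaN, 2.0) yields 2.0 under "IGNORE"
// semantics, but NaN under "PROPAGATE" semantics.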
59 template <typename OpTy>
60 static Value
61 materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
62  Value lhs, Value rhs, Value result) {
63  // NaN propagation has no meaning for non floating point types.
64  if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
65  return result;
66 
67  auto nanMode = op.getNanMode();
68  if (nanMode == NanPropagationMode::PROPAGATE)
69  return result;
70 
71  // Unordered comparison of NaN against itself will always return true.
72  Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
73  arith::CmpFPredicate::UNO, lhs, lhs);
74  Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
75  arith::CmpFPredicate::UNO, rhs, rhs);
76  Value rhsOrResult =
77  arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
78  return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
79  rhsOrResult);
80 }
81 
82 static Value createLinalgBodyCalculationForElementwiseOp(
83  Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
84  ConversionPatternRewriter &rewriter) {
85  Location loc = op->getLoc();
86  auto elementTy =
87  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
88 
89  // tosa::AbsOp
90  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
91  return math::AbsFOp::create(rewriter, loc, resultTypes, args);
92 
93  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
94  auto zero = arith::ConstantOp::create(rewriter, loc,
95  rewriter.getZeroAttr(elementTy));
96  auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
97  return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
98  }
99 
100  // tosa::AddOp
101  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
102  return arith::AddFOp::create(rewriter, loc, resultTypes, args);
103 
104  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
105  return arith::AddIOp::create(rewriter, loc, resultTypes, args);
106 
107  // tosa::SubOp
108  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
109  return arith::SubFOp::create(rewriter, loc, resultTypes, args);
110 
111  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
112  return arith::SubIOp::create(rewriter, loc, resultTypes, args);
113 
114  // tosa::IntDivOp
115  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
116  return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
117 
118  // tosa::ReciprocalOp
119  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
120  auto one =
121  arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
122  return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
123  }
124 
125  // tosa::MulOp
126  if (isa<tosa::MulOp>(op)) {
127  auto shiftVal = cast<tosa::MulOp>(op).getShift();
128  DenseElementsAttr shiftElem;
129  bool shiftIsConstant = true;
130  int32_t shift = 0;
131  if (matchPattern(shiftVal, m_Constant(&shiftElem)))
132  shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
133  else
134  shiftIsConstant = false;
135 
136  if (isa<FloatType>(elementTy)) {
137  if (shift != 0) {
138  (void)rewriter.notifyMatchFailure(op,
139  "Cannot have shift value for float");
140  return nullptr;
141  }
142  return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
143  args[1]);
144  }
145 
146  if (isa<IntegerType>(elementTy)) {
147  Value a = args[0];
148  Value b = args[1];
149 
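  // When the shift is non-zero or not a compile-time constant, the code below
  // sign-extends the operands to i32 and lowers the multiply to
  // tosa.apply_scale with SINGLE_ROUND rounding.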
150  if (shift > 0 || !shiftIsConstant) {
151  Value shiftConst;
152  if (shiftIsConstant)
153  shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
154  /*bitwidth=*/8);
155 
156  if (!a.getType().isInteger(32))
157  a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
158 
159  if (!b.getType().isInteger(32))
160  b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
161 
162  auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
163  auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
164  RoundingMode::SINGLE_ROUND);
165  auto result =
166  tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
167  b, shiftAmount, roundingAttr);
168 
169  return result;
170  }
171 
172  int aWidth = a.getType().getIntOrFloatBitWidth();
173  int bWidth = b.getType().getIntOrFloatBitWidth();
174  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
175 
176  if (aWidth < cWidth)
177  a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
178  if (bWidth < cWidth)
179  b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);
180 
181  return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
182  }
183  }
184 
185  // tosa::NegateOp
186  if (isa<tosa::NegateOp>(op)) {
187  auto negate = cast<tosa::NegateOp>(op);
188 
189  int64_t inZp = 0, outZp = 0;
190  FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
191  FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
192  bool hasInZp = !failed(maybeInZp);
193  bool hasOutZp = !failed(maybeOutZp);
194  if (hasInZp)
195  inZp = *maybeInZp;
196  if (hasOutZp)
197  outZp = *maybeOutZp;
198 
199  if (isa<FloatType>(elementTy))
200  return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
201 
202  if (isa<IntegerType>(elementTy)) {
203  if (hasInZp && hasOutZp && !inZp && !outZp) {
204  auto constant = arith::ConstantOp::create(
205  rewriter, loc, IntegerAttr::get(elementTy, 0));
206  return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
207  args[0]);
208  }
209 
210  Value zpAddValue;
211  Type intermediateType;
212  // Compute the maximum value that can occur in the intermediate buffer.
213  const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
214  int intermediateBitWidth = 64;
215 
216  if (hasInZp && hasOutZp) {
217  // Compute the maximum value that can occur in the intermediate buffer.
218  const int64_t zpAdd = inZp + outZp;
219  const int64_t maxValue =
220  APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
221  std::abs(zpAdd) + 1;
222 
223  // Convert that maximum value into the maximum bitwidth needed to
224  // represent it. We assume 48-bit numbers may be supported further in
225  // the pipeline.
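  // For example (illustrative): for an i8 input with inZp = -128 and
  // outZp = -128, zpAdd = -256 and maxValue = 127 + 256 + 1 = 384, which fits
  // in 16 bits, so a 16-bit intermediate type is selected.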
226  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
227  intermediateBitWidth = 16;
228  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
229  intermediateBitWidth = 32;
230  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
231  intermediateBitWidth = 48;
232  }
233 
234  intermediateType = rewriter.getIntegerType(intermediateBitWidth);
235  zpAddValue = arith::ConstantOp::create(
236  rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
237  } else {
238  intermediateType = rewriter.getIntegerType(intermediateBitWidth);
239  auto arg1 =
240  arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
241  auto arg2 =
242  arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
243  zpAddValue =
244  arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
245  }
246 
247  // The negation can be applied by doing:
248  // outputValue = inZp + outZp - inputValue
249  auto ext =
250  arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
251  auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
252 
253  // Clamp to the negation range.
254  Value min = arith::ConstantIntOp::create(
255  rewriter, loc, intermediateType,
256  APInt::getSignedMinValue(inputBitWidth).getSExtValue());
257  Value max = arith::ConstantIntOp::create(
258  rewriter, loc, intermediateType,
259  APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
260  auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
261 
262  // Truncate to the final value.
263  return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
264  }
265  }
266 
267  // tosa::BitwiseAndOp
268  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
269  return arith::AndIOp::create(rewriter, loc, resultTypes, args);
270 
271  // tosa::BitwiseOrOp
272  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
273  return arith::OrIOp::create(rewriter, loc, resultTypes, args);
274 
275  // tosa::BitwiseNotOp
276  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
277  auto allOnesAttr = rewriter.getIntegerAttr(
278  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
279  auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
280  return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
281  }
282 
283  // tosa::BitwiseXOrOp
284  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
285  return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
286 
287  // tosa::LogicalLeftShiftOp
288  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
289  return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
290 
291  // tosa::LogicalRightShiftOp
292  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
293  return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
294 
295  // tosa::ArithmeticRightShiftOp
296  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
297  auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
298  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
299  if (!round) {
300  return result;
301  }
302 
303  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
304  auto one = arith::ConstantOp::create(rewriter, loc,
305  IntegerAttr::get(elementTy, 1));
306  auto zero = arith::ConstantOp::create(rewriter, loc,
307  IntegerAttr::get(elementTy, 0));
308  auto i1zero =
309  arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
310  auto i1one =
311  arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
312 
313  // Check that the shift amount (input2) is greater than zero.
314  auto shiftValueGreaterThanZero = arith::CmpIOp::create(
315  rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
316 
317  // Check whether the last bit shifted out of input1 (bit input2 - 1) is 1.
318  auto subtract =
319  arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
320  auto shifted =
321  arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
322  ->getResults();
323  auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
324  std::nullopt);
325  auto isInputOdd =
326  arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
327  // shifted, truncated, isInputOdd can be poison when input2 is 0.
328  auto shouldRound = arith::SelectOp::create(
329  rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
330  auto extended =
331  arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
332  return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
333  }
334 
335  // tosa::ClzOp
336  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
337  return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
338  }
339 
340  // tosa::LogicalAnd
341  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
342  return arith::AndIOp::create(rewriter, loc, resultTypes, args);
343 
344  // tosa::LogicalNot
345  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
346  auto one = arith::ConstantOp::create(rewriter, loc,
347  rewriter.getIntegerAttr(elementTy, 1));
348  return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
349  }
350 
351  // tosa::LogicalOr
352  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
353  return arith::OrIOp::create(rewriter, loc, resultTypes, args);
354 
355  // tosa::LogicalXor
356  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
357  return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
358 
359  // tosa::PowOp
360  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
361  return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
362 
363  // tosa::RsqrtOp
364  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
365  return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
366 
367  // tosa::LogOp
368  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
369  return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
370 
371  // tosa::ExpOp
372  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
373  return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
374 
375  // tosa::SinOp
376  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
377  return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
378 
379  // tosa::CosOp
380  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
381  return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
382 
383  // tosa::TanhOp
384  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
385  return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
386 
387  // tosa::ErfOp
388  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
389  return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
390 
391  // tosa::GreaterOp
392  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
393  return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
394  args[0], args[1]);
395 
396  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
397  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
398  args[0], args[1]);
399 
400  // tosa::GreaterEqualOp
401  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
402  return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
403  args[0], args[1]);
404 
405  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
406  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
407  args[0], args[1]);
408 
409  // tosa::EqualOp
410  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
411  return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
412  args[0], args[1]);
413 
414  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
415  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
416  args[0], args[1]);
417 
418  // tosa::SelectOp
419  if (isa<tosa::SelectOp>(op)) {
420  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
421  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
422  return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
423  }
424 
425  // tosa::MaximumOp
426  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
427  auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
428  return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
429  rewriter, args[0], args[1], max);
430  }
431 
432  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
433  return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
434  }
435 
436  // tosa::MinimumOp
437  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
438  auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
439  return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
440  rewriter, args[0], args[1], min);
441  }
442 
443  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
444  return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
445  }
446 
447  // tosa::CeilOp
448  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
449  return math::CeilOp::create(rewriter, loc, resultTypes, args);
450 
451  // tosa::FloorOp
452  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
453  return math::FloorOp::create(rewriter, loc, resultTypes, args);
454 
455  // tosa::ClampOp
456  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
457  bool losesInfo = false;
458  APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
459  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
460  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
461  APFloat::rmNearestTiesToEven, &losesInfo);
462  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
463  APFloat::rmNearestTiesToEven, &losesInfo);
464  auto min = arith::ConstantOp::create(
465  rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
466  auto max = arith::ConstantOp::create(
467  rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
468  auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
469 
470  auto clampOp = llvm::cast<tosa::ClampOp>(op);
471  const auto nanMode = clampOp.getNanMode();
472 
473  // NaN propagation has no meaning for non floating point types.
474  if (!isa<FloatType>(elementTy))
475  return result;
476 
477  // In the case of "PROPAGATE" semantics no compare and selection is
478  // required.
479  if (nanMode == NanPropagationMode::PROPAGATE)
480  return result;
481 
482  // In the case of "IGNORE" semantics materialize a comparison
483  // of the current operand to the reduction which will return true for a NaN
484  // argument and then selects between the initial reduction value and the
485  // calculated result based on whether the argument is NaN or not. In pseudo
486  // code:
487  //
488  // reduce<op>(x, init):
489  // result = op(init, x)
490  // return init if x == NaN else result
491 
492  // Unordered comparison of NaN against itself will always return true.
493  Value isNaN = arith::CmpFOp::create(
494  rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
495  // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
496  // is NaN.
497  return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
498  }
499 
500  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
501  auto intTy = cast<IntegerType>(elementTy);
502  int64_t min =
503  cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
504  int64_t max =
505  cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
506 
507  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
508  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
509  if (intTy.isUnsignedInteger()) {
510  minRepresentable = 0;
511  if (intTy.getIntOrFloatBitWidth() <= 63) {
512  maxRepresentable =
513  (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
514  .getZExtValue();
515  }
516  } else if (intTy.getIntOrFloatBitWidth() <= 64) {
517  // Ensure that min & max fit into signed n-bit constants.
518  minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
519  .getSExtValue();
520  maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
521  .getSExtValue();
522  }
523  // Ensure that the bounds are representable as n-bit signed/unsigned
524  // integers.
525  min = std::max(min, minRepresentable);
526  max = std::max(max, minRepresentable);
527  min = std::min(min, maxRepresentable);
528  max = std::min(max, maxRepresentable);
529 
530  auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
531  intTy.getIntOrFloatBitWidth());
532  auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
533  intTy.getIntOrFloatBitWidth());
534  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
535  intTy.isUnsignedInteger());
536  }
537 
538  // tosa::SigmoidOp
539  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
540  auto one =
541  arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
542  auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
543  auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
544  auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
545  return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
546  }
547 
548  // tosa::CastOp
549  if (isa<tosa::CastOp>(op)) {
550  Type srcTy = elementTy;
551  Type dstTy = resultTypes.front();
552  if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
553  (void)rewriter.notifyMatchFailure(op, "unsupported type");
554  return nullptr;
555  }
556 
557  bool bitExtend =
558  srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
559 
560  if (srcTy == dstTy)
561  return args.front();
562 
563  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
564  return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
565  ArrayRef<NamedAttribute>());
566 
567  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
568  return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
569  ArrayRef<NamedAttribute>());
570 
571  // 1-bit integers need to be treated as signless.
572  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
573  return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
574  ArrayRef<NamedAttribute>());
575 
576  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
577  return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
578  ArrayRef<NamedAttribute>());
579 
580  // Unsigned integers need an unrealized cast so that they can be passed
581  // to UIToFP.
582  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
583  auto unrealizedCast =
584  UnrealizedConversionCastOp::create(
585  rewriter, loc,
586  rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
587  .getResult(0);
588  return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
589  unrealizedCast);
590  }
591 
592  // All other si-to-fp conversions should be handled by SIToFP.
593  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
594  return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
595  ArrayRef<NamedAttribute>());
596 
597  // Casting to boolean, floats need to only be checked as not-equal to zero.
598  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
599  Value zero = arith::ConstantOp::create(rewriter, loc,
600  rewriter.getFloatAttr(srcTy, 0.0));
601  return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
602  args.front(), zero);
603  }
604 
605  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
606  auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
607 
608  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
609  // Check whether neither int min nor int max can be represented in the
610  // input floating-point type due to too short exponent range.
611  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
612  APFloat::semanticsMaxExponent(fltSemantics)) {
613  // Use cmp + select to replace infinites by int min / int max. Other
614  // integral values can be represented in the integer space.
615  auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
616  auto posInf = arith::ConstantOp::create(
617  rewriter, loc,
618  rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
619  APFloat::getInf(fltSemantics)));
620  auto negInf = arith::ConstantOp::create(
621  rewriter, loc,
622  rewriter.getFloatAttr(
623  getElementTypeOrSelf(srcTy),
624  APFloat::getInf(fltSemantics, /*Negative=*/true)));
625  auto overflow = arith::CmpFOp::create(
626  rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
627  auto underflow = arith::CmpFOp::create(
628  rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
629  auto intMin = arith::ConstantOp::create(
630  rewriter, loc,
631  rewriter.getIntegerAttr(
632  getElementTypeOrSelf(dstTy),
633  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
634  auto intMax = arith::ConstantOp::create(
635  rewriter, loc,
636  rewriter.getIntegerAttr(
637  getElementTypeOrSelf(dstTy),
638  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
639  auto maxClamped =
640  arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
641  return arith::SelectOp::create(rewriter, loc, underflow, intMin,
642  maxClamped);
643  }
644 
645  auto intMinFP = arith::ConstantOp::create(
646  rewriter, loc,
647  rewriter.getFloatAttr(
648  getElementTypeOrSelf(srcTy),
649  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
650  .getSExtValue()));
651 
652  // Check whether the mantissa has enough bits to represent int max.
653  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
654  dstTy.getIntOrFloatBitWidth() - 1) {
655  // Int min can also be represented since it is a power of two and thus
656  // consists of a single leading bit. Therefore we can clamp the input
657  // in the floating-point domain.
658 
659  auto intMaxFP = arith::ConstantOp::create(
660  rewriter, loc,
661  rewriter.getFloatAttr(
662  getElementTypeOrSelf(srcTy),
663  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
664  .getSExtValue()));
665 
666  Value clamped =
667  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
668  return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
669  }
670 
671  // Due to the earlier check we know the exponent range is big enough to
672  // represent int min. We can therefore rely on int max + 1 being representable
673  // as well, because it is just int min with a positive sign. So clamp the min
674  // value and compare against that to select the max int value if needed.
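  // For example (illustrative): when casting f32 to i64, the 24-bit f32
  // mantissa cannot represent int64 max exactly, but int64 max + 1 = 2^63 is
  // a power of two and therefore exactly representable, making it a safe
  // overflow threshold.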
675  auto intMaxPlusOneFP = arith::ConstantOp::create(
676  rewriter, loc,
677  rewriter.getFloatAttr(
678  getElementTypeOrSelf(srcTy),
679  static_cast<double>(
680  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
681  .getSExtValue()) +
682  1.0f));
683 
684  auto intMax = arith::ConstantOp::create(
685  rewriter, loc,
686  rewriter.getIntegerAttr(
687  getElementTypeOrSelf(dstTy),
688  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
689  auto minClampedFP =
690  arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
691  auto minClamped =
692  arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
693  auto overflow = arith::CmpFOp::create(
694  rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
695  return arith::SelectOp::create(rewriter, loc, overflow, intMax,
696  minClamped);
697  }
698 
699  // Casting to boolean, integers need to only be checked as not-equal to
700  // zero.
701  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
702  Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
703  srcTy.getIntOrFloatBitWidth());
704  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
705  args.front(), zero);
706  }
707 
708  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
709  return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
710  ArrayRef<NamedAttribute>());
711 
712  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
713  return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
714  }
715  }
716 
717  (void)rewriter.notifyMatchFailure(
718  op, "unhandled op for linalg body calculation for elementwise op");
719  return nullptr;
720 }
721 
722 using IndexPool = DenseMap<int64_t, Value>;
723 
724 // Emit an 'arith.constant' op for the given index if it has not been created
725 // yet, or return an existing constant. This will prevent an excessive creation
726 // of redundant constants, easing readability of emitted code for unit tests.
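// For example (illustrative): repeated requests for the same index while
// broadcasting several operands all reuse a single 'arith.constant ... : index'
// op from the pool.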
727 static Value createIndex(PatternRewriter &rewriter, Location loc,
728  IndexPool &indexPool, int64_t index) {
729  auto [it, inserted] = indexPool.try_emplace(index);
730  if (inserted)
731  it->second =
732  arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
733  return it->second;
734 }
735 
736 static Value getTensorDim(PatternRewriter &rewriter, Location loc,
737  IndexPool &indexPool, Value tensor, int64_t index) {
738  auto indexValue = createIndex(rewriter, loc, indexPool, index);
739  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
740 }
741 
742 static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
743  IndexPool &indexPool, Value tensor,
744  int64_t index) {
745  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
746  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
747  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
748  if (shapedType.isDynamicDim(index))
749  return getTensorDim(rewriter, loc, indexPool, tensor, index);
750  return rewriter.getIndexAttr(shapedType.getDimSize(index));
751 }
752 
753 static bool operandsAndResultsRanked(Operation *operation) {
754  auto isRanked = [](Value value) {
755  return isa<RankedTensorType>(value.getType());
756  };
757  return llvm::all_of(operation->getOperands(), isRanked) &&
758  llvm::all_of(operation->getResults(), isRanked);
759 }
760 
761 // Compute the runtime dimension size for dimension 'dim' of the output by
762 // inspecting input 'operands', all of which are expected to have the same rank.
763 // This function returns a pair {targetSize, masterOperand}.
764 //
765 // The runtime size of the output dimension is returned either as a statically
766 // computed attribute or as a runtime SSA value.
767 //
768 // If the target size was inferred directly from one dominating operand, that
769 // operand is returned in 'masterOperand'. If the target size is inferred from
770 // multiple operands, 'masterOperand' is set to nullptr.
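// For example (illustrative): for operands of types tensor<3x?xf32> and
// tensor<1x?xf32> with dim = 0, the target size is the static 3 and the first
// operand is the master operand; with dim = 1 both sizes are dynamic, so the
// runtime maximum is computed and no master operand is recorded.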
771 static std::pair<OpFoldResult, Value>
772 computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
773  ValueRange operands, int64_t dim) {
774  // If any input operand contains a static size greater than 1 for this
775  // dimension, that is the target size. An occurrence of an additional static
776  // dimension greater than 1 with a different value is undefined behavior.
777  for (auto operand : operands) {
778  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
779  if (ShapedType::isStatic(size) && size > 1)
780  return {rewriter.getIndexAttr(size), operand};
781  }
782 
783  // Filter operands with dynamic dimension
784  auto operandsWithDynamicDim =
785  llvm::filter_to_vector(operands, [&](Value operand) {
786  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
787  });
788 
789  // If no operand has a dynamic dimension, it means all sizes were 1
790  if (operandsWithDynamicDim.empty())
791  return {rewriter.getIndexAttr(1), operands.front()};
792 
793  // Emit code that computes the runtime size for this dimension. If there is
794  // only one operand with a dynamic dimension, it is considered the master
795  // operand that determines the runtime size of the output dimension.
796  auto targetSize =
797  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
798  if (operandsWithDynamicDim.size() == 1)
799  return {targetSize, operandsWithDynamicDim[0]};
800 
801  // Calculate maximum size among all dynamic dimensions
802  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
803  auto nextSize =
804  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
805  targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
806  }
807  return {targetSize, nullptr};
808 }
809 
810 // Compute the runtime output size for all dimensions. This function returns
811 // a pair {targetShape, masterOperands}.
812 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
813 computeTargetShape(PatternRewriter &rewriter, Location loc,
814  IndexPool &indexPool, ValueRange operands) {
815  assert(!operands.empty());
816  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
817  SmallVector<OpFoldResult> targetShape;
818  SmallVector<Value> masterOperands;
819  for (auto dim : llvm::seq<int64_t>(0, rank)) {
820  auto [targetSize, masterOperand] =
821  computeTargetSize(rewriter, loc, indexPool, operands, dim);
822  targetShape.push_back(targetSize);
823  masterOperands.push_back(masterOperand);
824  }
825  return {targetShape, masterOperands};
826 }
827 
828 static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
829  IndexPool &indexPool, Value operand,
830  int64_t dim, OpFoldResult targetSize,
831  Value masterOperand) {
832  // Nothing to do if this is a static dimension
833  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
834  if (!rankedTensorType.isDynamicDim(dim))
835  return operand;
836 
837  // If the target size for this dimension was directly inferred by only taking
838  // this operand into account, there is no need to broadcast. This is an
839  // optimization that will prevent redundant control flow, and constitutes the
840  // main motivation for tracking "master operands".
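  // For example (illustrative): if this operand was the only one with a
  // dynamic size in this dimension, it alone determined the target size, so no
  // 'scf.if' broadcast check needs to be emitted for it.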
841  if (operand == masterOperand)
842  return operand;
843 
844  // Affine maps for 'linalg.generic' op
845  auto rank = rankedTensorType.getRank();
846  SmallVector<AffineExpr> affineExprs;
847  for (auto index : llvm::seq<int64_t>(0, rank)) {
848  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
849  : rewriter.getAffineDimExpr(index);
850  affineExprs.push_back(affineExpr);
851  }
852  auto broadcastAffineMap =
853  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
854  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
855  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
856 
857  // Check if broadcast is necessary
858  auto one = createIndex(rewriter, loc, indexPool, 1);
859  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
860  auto broadcastNecessary = arith::CmpIOp::create(
861  rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
862 
863  // Emit 'then' region of 'scf.if'
864  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
865  // It is not safe to cache constants across regions.
866  // New constants could potentially violate dominance requirements.
867  IndexPool localPool;
868 
869  // Emit 'tensor.empty' op
870  SmallVector<OpFoldResult> outputTensorShape;
871  for (auto index : llvm::seq<int64_t>(0, rank)) {
872  auto size = index == dim ? targetSize
873  : getOrFoldTensorDim(rewriter, loc, localPool,
874  operand, index);
875  outputTensorShape.push_back(size);
876  }
877  Value outputTensor = tensor::EmptyOp::create(
878  opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
879 
880  // Emit 'linalg.generic' op
881  auto resultTensor =
882  linalg::GenericOp::create(
883  opBuilder, loc, outputTensor.getType(), operand, outputTensor,
884  affineMaps, getNParallelLoopsAttrs(rank),
885  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
886  // Emit 'linalg.yield' op
887  linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
888  })
889  .getResult(0);
890 
891  // Cast to original operand type if necessary
892  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
893  loc, operand.getType(), resultTensor);
894 
895  // Emit 'scf.yield' op
896  scf::YieldOp::create(opBuilder, loc, castResultTensor);
897  };
898 
899  // Emit 'else' region of 'scf.if'
900  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
901  scf::YieldOp::create(opBuilder, loc, operand);
902  };
903 
904  // Emit 'scf.if' op
905  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
906  emitThenRegion, emitElseRegion);
907  return ifOp.getResult(0);
908 }
909 
910 static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
911  IndexPool &indexPool, Value operand,
912  ArrayRef<OpFoldResult> targetShape,
913  ArrayRef<Value> masterOperands) {
914  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
915  assert((int64_t)targetShape.size() == rank);
916  assert((int64_t)masterOperands.size() == rank);
917  for (auto index : llvm::seq<int64_t>(0, rank))
918  operand =
919  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
920  targetShape[index], masterOperands[index]);
921  return operand;
922 }
923 
924 static SmallVector<Value>
925 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
926  IndexPool &indexPool, ValueRange operands,
927  ArrayRef<OpFoldResult> targetShape,
928  ArrayRef<Value> masterOperands) {
929  // No need to broadcast for unary operations
930  if (operands.size() == 1)
931  return operands;
932 
933  // No need to broadcast for static shape
934  bool hasDynamic = false;
935  for (auto op : operands) {
936  const auto tType = dyn_cast<RankedTensorType>(op.getType());
937  if (tType && !tType.hasStaticShape()) {
938  hasDynamic = true;
939  break;
940  }
941  }
942  if (!hasDynamic)
943  return operands;
944 
945  // Broadcast dynamic dimensions operand by operand
946  return llvm::map_to_vector(operands, [&](Value operand) {
947  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
948  targetShape, masterOperands);
949  });
950 }
951 
952 static LogicalResult
953 emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
954  Operation *operation, ValueRange operands,
955  ArrayRef<OpFoldResult> targetShape,
956  const TypeConverter &converter) {
957  // Generate output tensor
958  auto resultType = cast_or_null<RankedTensorType>(
959  converter.convertType(operation->getResultTypes().front()));
960  if (!resultType) {
961  return rewriter.notifyMatchFailure(operation, "failed to convert type");
962  }
963  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
964  resultType.getElementType());
965 
966  // Create affine maps. Input affine maps broadcast static dimensions of size
967  // 1. The output affine map is an identity map.
968  //
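  // For example (illustrative): for a rank-2 result and an operand of type
  // tensor<1x?xf32> whose corresponding result dimension 0 is not 1, the input
  // map is affine_map<(d0, d1) -> (0, d1)>; otherwise the identity map is used.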
969  auto rank = resultType.getRank();
970  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
971  auto shape = cast<ShapedType>(operand.getType()).getShape();
972  SmallVector<AffineExpr> affineExprs;
973  for (auto it : llvm::enumerate(shape)) {
974  // Prefer producing identity maps whenever possible (i.e. no broadcasting
975  // needed) because some transforms (like reshape folding)
976  // do not support affine constant exprs.
977  bool requiresBroadcast =
978  (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
979  auto affineExpr = requiresBroadcast
980  ? rewriter.getAffineConstantExpr(0)
981  : rewriter.getAffineDimExpr(it.index());
982  affineExprs.push_back(affineExpr);
983  }
984  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
985  });
986  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
987 
988  // Emit 'linalg.generic' op
989  bool encounteredError = false;
990  auto linalgOp = linalg::GenericOp::create(
991  rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
992  getNParallelLoopsAttrs(rank),
993  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
994  Value opResult = createLinalgBodyCalculationForElementwiseOp(
995  operation, blockArgs.take_front(operation->getNumOperands()),
996  {resultType.getElementType()}, rewriter);
997  if (!opResult) {
998  encounteredError = true;
999  return;
1000  }
1001  linalg::YieldOp::create(opBuilder, loc, opResult);
1002  });
1003  if (encounteredError)
1004  return rewriter.notifyMatchFailure(
1005  operation, "unable to create linalg.generic body for elementwise op");
1006 
1007  // Cast 'linalg.generic' result into original result type if needed
1008  auto castResult = rewriter.createOrFold<tensor::CastOp>(
1009  loc, resultType, linalgOp->getResult(0));
1010  rewriter.replaceOp(operation, castResult);
1011  return success();
1012 }
1013 
1014 static ValueRange getBroadcastableOperands(Operation *operation,
1015  ValueRange operands) {
1016  // Shift cannot broadcast
1017  if (isa<tosa::MulOp>(operation)) {
1018  DenseElementsAttr shiftElems;
1019  // Shift cannot broadcast when it is constant
1020  if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
1021  return operands.take_front(2);
1022  else
1023  return operands.take_front(3);
1024  }
1025  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1026  FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
1027  FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
1028  if (failed(maybeOutZp) && failed(maybeInZp))
1029  return operands;
1030  // Input1_zp and output_zp cannot broadcast when they are constants.
1031  return operands.take_front(1);
1032  }
1033  return operands;
1034 }
1035 
1036 static LogicalResult
1037 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
1038  ConversionPatternRewriter &rewriter,
1039  const TypeConverter &converter) {
1040 
1041  // Collect op properties
1042  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1043  assert(operation->getNumOperands() >= 1 &&
1044  "elementwise op expects at least 1 operand");
1045  if (!operandsAndResultsRanked(operation))
1046  return rewriter.notifyMatchFailure(operation,
1047  "Unranked tensors not supported");
1048 
1049  // Lower operation
1050  IndexPool indexPool;
1051  auto loc = operation->getLoc();
1052  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1053  auto [targetShape, masterOperands] =
1054  computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
1055  auto broadcastOperands =
1056  broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
1057  targetShape, masterOperands);
1058  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
1059  targetShape, converter);
1060 }
1061 
1062 // Returns the constant initial value for a given reduction operation. The
1063 // attribute type varies depending on the element type required.
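// For example (illustrative): reduce_sum starts from 0, reduce_product from 1,
// and reduce_min from the largest (most positive) value of the element type.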
1064 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1065  PatternRewriter &rewriter) {
1066  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1067  return rewriter.getFloatAttr(elementTy, 0.0);
1068 
1069  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1070  return rewriter.getIntegerAttr(elementTy, 0);
1071 
1072  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1073  return rewriter.getFloatAttr(elementTy, 1.0);
1074 
1075  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1076  return rewriter.getIntegerAttr(elementTy, 1);
1077 
1078  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1079  return rewriter.getFloatAttr(
1080  elementTy, APFloat::getLargest(
1081  cast<FloatType>(elementTy).getFloatSemantics(), false));
1082 
1083  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1084  return rewriter.getIntegerAttr(
1085  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1086 
1087  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1088  return rewriter.getFloatAttr(
1089  elementTy, APFloat::getLargest(
1090  cast<FloatType>(elementTy).getFloatSemantics(), true));
1091 
1092  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1093  return rewriter.getIntegerAttr(
1094  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1095 
1096  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1097  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1098 
1099  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1100  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1101 
1102  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1103  return rewriter.getFloatAttr(
1104  elementTy, APFloat::getLargest(
1105  cast<FloatType>(elementTy).getFloatSemantics(), true));
1106 
1107  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1108  return rewriter.getIntegerAttr(
1109  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1110 
1111  return {};
1112 }
1113 
1114 // Creates the body calculation for a reduction. The operations vary depending
1115 // on the input type.
1116 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1117  ValueRange args,
1118  Type elementTy,
1119  PatternRewriter &rewriter) {
1120  Location loc = op->getLoc();
1121  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1122  return arith::AddFOp::create(rewriter, loc, args);
1123  }
1124 
1125  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1126  return arith::AddIOp::create(rewriter, loc, args);
1127  }
1128 
1129  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1130  return arith::MulFOp::create(rewriter, loc, args);
1131  }
1132 
1133  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1134  return arith::MulIOp::create(rewriter, loc, args);
1135  }
1136 
1137  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1138  return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
1139  }
1140 
1141  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1142  return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
1143  }
1144 
1145  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1146  return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
1147  }
1148 
1149  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1150  return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
1151  }
1152 
1153  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1154  return arith::AndIOp::create(rewriter, loc, args);
1155 
1156  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1157  return arith::OrIOp::create(rewriter, loc, args);
1158 
1159  return {};
1160 }
1161 
1162 // Performs the match and rewrite for reduction operations. This includes
1163 // declaring a correctly sized initial value, and the linalg.generic operation
1164 // that reduces across the specified axis.
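// For example (illustrative): reduce_sum over axis 1 of tensor<4x5xf32>
// produces a linalg.reduce result of type tensor<4xf32>, which is then
// expanded back to the TOSA result type tensor<4x1xf32> with
// tensor.expand_shape.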
1165 template <typename OpTy>
1166 static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1167  PatternRewriter &rewriter) {
1168  auto loc = op->getLoc();
1169  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1170  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1171  if (!inputTy || !resultTy)
1172  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1173 
1174  auto elementTy = resultTy.getElementType();
1175  Value input = op->getOperand(0);
1176 
1177  // Figure out the accType if needed
1178  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1179  isa<FloatType>(elementTy) &&
1180  cast<FloatType>(elementTy).isBF16();
1181  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
1182 
1183  SmallVector<int64_t> reduceShape;
1184  SmallVector<Value> dynDims;
1185  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1186  if (axis != i) {
1187  reduceShape.push_back(inputTy.getDimSize(i));
1188  if (inputTy.isDynamicDim(i))
1189  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1190  }
1191  }
1192 
1193  SmallVector<Value> inputs, outputs;
1194  inputs.push_back(input);
1195 
1196  // First fill the output buffer with the init value.
1197  auto emptyTensor =
1198  tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1199  .getResult();
1200 
1201  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
1202  if (!fillValueAttr)
1203  return rewriter.notifyMatchFailure(
1204  op, "No initial value found for reduction operation");
1205 
1206  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
1207  auto filledTensor =
1208  linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
1209  ValueRange{emptyTensor})
1210  .result();
1211  outputs.push_back(filledTensor);
1212 
1213  bool isNanIgnoreMode = false;
1214  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1215  std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1216  // NaN propagation has no meaning for non floating point types.
1217  if (isa<FloatType>(elementTy) &&
1218  op.getNanMode() == NanPropagationMode::IGNORE) {
1219  isNanIgnoreMode = true;
1220  // Because the TOSA spec requires the result be NaN iff all elements in
1221  // the reduction are NaN we can't simply perform a compare and select.
1222  // Additionally we have to keep track of whether we've seen any non-NaN
1223  // values and then do a final select based on this predicate.
1224  auto trueAttr = rewriter.getBoolAttr(true);
1225  auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
1226  auto emptyBoolTensor =
1227  tensor::EmptyOp::create(rewriter, loc, reduceShape,
1228  trueValue.getType(), dynDims)
1229  .getResult();
1230  auto allResultsNaNTensor =
1231  linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
1232  ValueRange{emptyBoolTensor})
1233  .result();
1234  // Note that because the linalg::ReduceOp has two variadic arguments
1235  // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1236  // need to have the same number of inputs and outputs.
1237  //
1238  // The second input isn't actually used anywhere since the value used to
1239  // update the NaN flag is calculated inside the body of the reduction and
1240  // then used to update an out value.
1241  // In order to satisfy type constraints we just pass another copy of the
1242  // input here.
1243  inputs.push_back(input);
1244  outputs.push_back(allResultsNaNTensor);
1245  }
1246  }
1247 
1248  bool didEncounterError = false;
1249  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
1250  rewriter, loc, inputs, outputs, axis,
1251  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1252  std::array<Value, 2> binaryArgs{
1253  blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1254 
1255  // If reduction type differs then extend (applicable to reduce_sum)
1256  if (binaryArgs[0].getType() != accTy)
1257  binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
1258  binaryArgs[0]);
1259 
1260  auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
1261  accTy, rewriter);
1262  if (result)
1263  didEncounterError = true;
1264 
1265  SmallVector<Value> resultsToYield;
1266  if (isNanIgnoreMode) {
1267  auto inputValue = blockArgs[0];
1268  auto initialValue = blockArgs[2];
1269  auto oldAllResultsNanFlagValue = blockArgs[3];
1270 
1271  // Unordered comparison of NaN against itself will always return true.
1272  Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
1273  arith::CmpFPredicate::UNO,
1274  inputValue, inputValue);
1275  // If we've encountered a NaN, take the non-NaN value.
1276  auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
1277  isNaN, initialValue, result);
1278  // Update the flag which tracks whether every element seen so far has
1279  // been NaN.
1280  auto newAllResultsNanFlagValue = arith::AndIOp::create(
1281  nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1282  resultsToYield.push_back(selectOp);
1283  resultsToYield.push_back(newAllResultsNanFlagValue);
1284  } else {
1285  resultsToYield.push_back(result);
1286  }
1287  linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
1288  });
1289 
1290  if (!didEncounterError)
1291  return rewriter.notifyMatchFailure(
1292  op, "unable to create linalg.generic body for reduce op");
1293 
1294  if (isNanIgnoreMode) {
1295  // Materialize a check to see whether we encountered any non-NaN values; if
1296  // we did not, we need to select a tensor of NaNs since the result will just
1297  // be the initial identity value propagated through all the compares and
1298  // selects inside the reduction.
1299 
1300  // Create a tensor full of NaNs.
1301  auto nanValueAttr = rewriter.getFloatAttr(
1302  accTy,
1303  APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1304  auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
1305  auto emptyNanTensor =
1306  tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1307  .getResult();
1308  auto nanFilledTensor =
1309  linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
1310  ValueRange{emptyNanTensor})
1311  .result();
1312 
1313  // Create an empty tensor; no need to fill it since it will be
1314  // overwritten by the select.
1315  auto finalEmptyTensor =
1316  tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1317  .getResult();
1318 
1319  // Do a selection between the tensors akin to:
1320  // result = NaN if "all results NaN" else result.
1321  SmallVector<Value> ins, outs;
1322  ins.push_back(linalgOp->getOpResult(1));
1323  ins.push_back(nanFilledTensor);
1324  ins.push_back(linalgOp->getResult(0));
1325  outs.push_back(finalEmptyTensor);
1326  auto linalgSelect =
1327  linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
1328  linalgOp = linalgSelect;
1329  }
1330 
1331  // Truncate back to resultTy if needed
1332  Value reducedRes = linalgOp->getResult(0);
1333  if (widenAccTy) {
1334  auto resEmptyOp =
1335  tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
1336  .getResult();
1337 
1338  const unsigned reducedRank =
1339  cast<ShapedType>(reducedRes.getType()).getRank();
1340  auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
1341  reducedRes =
1342  linalg::GenericOp::create(
1343  rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
1344  ValueRange{resEmptyOp},
1345  ArrayRef<AffineMap>{identityMap, identityMap},
1346  getNParallelLoopsAttrs(reducedRank),
1347  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1348  Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
1349  elementTy, args[0]);
1350  linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
1351  })
1352  .getResults()[0];
1353  }
1354 
1355  SmallVector<ReassociationExprs, 4> reassociationMap;
1356  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
1357  reassociationMap.resize(expandInputRank);
1358 
1359  for (uint64_t i = 0; i < expandInputRank; i++) {
1360  int32_t dimToPush = i > axis ? i + 1 : i;
1361  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1362  }
1363 
1364  if (expandInputRank != 0) {
1365  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1366  reassociationMap[expandedDim].push_back(
1367  rewriter.getAffineDimExpr(expandedDim + 1));
1368  }
1369 
1370  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1371  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1372  // not have access to such information. This matters when handling dynamically
1373  // sized tensors.
1374  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
1375  reassociationMap);
1376  return success();
1377 }
1378 
1379 namespace {
1380 
1381 template <typename SrcOp>
1382 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1383 public:
1384  using OpConversionPattern<SrcOp>::OpConversionPattern;
1385  using typename OpConversionPattern<SrcOp>::OpAdaptor;
1386 
1387  LogicalResult
1388  matchAndRewrite(SrcOp op, OpAdaptor operands,
1389  ConversionPatternRewriter &rewriter) const final {
1390  return elementwiseMatchAndRewriteHelper(
1391  op, operands.getOperands(), rewriter, *this->getTypeConverter());
1392  }
1393 };
1394 
1395 // Collapse tensor<1xiN> into tensor<iN>
1396 // E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1397 static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
1398  Location loc) {
1399  SmallVector<ReassociationExprs, 1> reassociation;
1400  // Create the collapsed type
1401  auto inputType = cast<RankedTensorType>(input.getType());
1402  auto elemType = inputType.getElementType();
1403  auto collapsedType = RankedTensorType::get({}, elemType);
1404  // Emit the collapse op
1405  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
1406  reassociation);
1407 }
1408 
1409 static llvm::SmallVector<int8_t>
1410 convertToI8(const llvm::SmallVector<int32_t> &input) {
1411  llvm::SmallVector<int8_t> output;
1412  output.reserve(input.size());
1413 
1414  for (auto v : llvm::map_range(
1415  input, [](int32_t val) { return static_cast<int8_t>(val); })) {
1416  output.push_back(v);
1417  }
1418  return output;
1419 }
1420 
1421 // The shift or multiplier may be either constant or non-constant, depending on
1422 // whether dynamic extension is enabled.
1423 // - If the shift or multiplier is non-constant, add it as an input to
1424 // linalg::GenericOp by:
1425 // 1. Pushing it into 'genericInputs'.
1426 // 2. Appending a corresponding affine map to 'indexingMaps'.
1427 // - If the shift or multiplier is constant, set 'constant' instead.
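// For example (illustrative): a constant per-channel multiplier with N values
// becomes a dense tensor<Nxi32> input mapped along the innermost (channel)
// dimension, while a non-constant scalar shift of type tensor<1xi8> is
// collapsed to tensor<i8> and given a zero-result broadcast map.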
1428 static void setupLinalgGenericOpInputAndIndexingMap(
1429  PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
1430  SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1431  bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
1432  bool isShift = false) {
1433 
1434  auto loc = op.getLoc();
1435  auto inputTy = cast<ShapedType>(op.getInput().getType());
1436  unsigned rank = inputTy.getRank();
1437  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
1438 
1439  if (isConstant) {
1440  // If we are rescaling per-channel then we need to store the
1441  // values in a buffer.
1442  if (values.size() == 1) {
1443  IntegerAttr intAttr = isShift
1444  ? rewriter.getI8IntegerAttr(values.front())
1445  : rewriter.getI32IntegerAttr(values.front());
1446  constant = arith::ConstantOp::create(rewriter, loc, intAttr);
1447  } else {
1448  auto elementType =
1449  isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
1450  auto tensorType = RankedTensorType::get(
1451  {static_cast<int64_t>(values.size())}, elementType);
1452  DenseIntElementsAttr EltAttr;
1453  if (isShift)
1454  EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
1455  else
1456  EltAttr = DenseIntElementsAttr::get(tensorType, values);
1457  genericInputs.push_back(
1458  arith::ConstantOp::create(rewriter, loc, EltAttr));
1459  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1460  /*symbolCount=*/0, exprs,
1461  rewriter.getContext()));
1462  }
1463  } else {
1464  // If we are not rescaling per-channel then we need to collapse 1xN to N
1465  // and push broadcastMap.
1466  auto operand = isShift ? op.getShift() : op.getMultiplier();
1467  auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
1468  if (tensorType && tensorType.hasStaticShape() &&
1469  tensorType.getShape()[0] == 1) {
1470  // broadcastMap = affine_map<(d0, d1) -> ()>
1471  // It acts as a broadcast of the scalar value in linalg::GenericOp.
1472  AffineMap broadcastMap =
1473  AffineMap::get(rank, 0, {}, rewriter.getContext());
1474  genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
1475  indexingMaps.push_back(broadcastMap);
1476  } else {
1477  genericInputs.push_back(operand);
1478  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1479  /*symbolCount=*/0, exprs,
1480  rewriter.getContext()));
1481  }
1482  }
1483  arg = indexingMaps.size() - 1;
1484 }
1485 
1486 // Return the extended Zp to be used in subsequent arithmetic operations.
1487 static Value getExtendZp(OpBuilder &builder, Type valueTy,
1488  FailureOr<int64_t> maybeZp, Location loc,
1489  ValueRange blockArgs, int64_t zpArg,
1490  bool isOutputZp = false) {
1491  Value result;
1492  const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1493  const uint32_t attrBitwidth =
1494  isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1495  auto extendType = builder.getIntegerType(attrBitwidth);
1496  // The Zp value can be either constant or non-constant, depending on
1497  // whether dynamic extension is enabled.
1498  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1499  // be passed as an input to linalg::GenericOp.
1500  if (failed(maybeZp)) {
1501  result = blockArgs[zpArg];
1502  auto zpTy = result.getType();
1503  if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1504  // For ExtUIOp, the input must be signless.
1505  // UnrealizedConversionCastOp will cast the input to signless type.
1506  if (zpTy.isUnsignedInteger()) {
1507  result =
1508  UnrealizedConversionCastOp::create(
1509  builder, loc,
1510  builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
1511  .getResult(0);
1512  }
1513  if (zpTy.isUnsignedInteger()) {
1514  return arith::ExtUIOp::create(builder, loc, extendType, result);
1515  } else {
1516  return arith::ExtSIOp::create(builder, loc, extendType, result);
1517  }
1518  }
1519  } else {
1520  return arith::ConstantOp::create(builder, loc,
1521  IntegerAttr::get(extendType, *maybeZp));
1522  }
1523  return result;
1524 }
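// Illustrative examples of getExtendZp (assumed values, not from the source):
// - a constant input zero point of -128 for an i8 value becomes
//     arith.constant -128 : i32
// - a non-constant unsigned i16 zero point arrives as a block argument, is
//   cast to signless i16 via unrealized_conversion_cast, and is then
//   zero-extended to i32 with arith.extui.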
1525 
1526 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1527 public:
1528  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1529 
1530  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1531  PatternRewriter &rewriter) const final {
1532  auto loc = op.getLoc();
1533  auto input = op.getInput();
1534  auto inputTy = cast<ShapedType>(op.getInput().getType());
1535  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1536  unsigned rank = inputTy.getRank();
1537 
1538  // These configurations are not supported; terminate and report a match failure.
1539  if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
1540  return rewriter.notifyMatchFailure(
1541  op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1542  "currently supported");
1543  if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
1544  return rewriter.notifyMatchFailure(
1545  op, "tosa.rescale requires scale32 for double_round to be true");
1546 
1547  if (!isa<IntegerType>(inputTy.getElementType()))
1548  return rewriter.notifyMatchFailure(op, "only support integer type");
1549 
1550  SmallVector<Value> dynDims;
1551  for (int i = 0; i < outputTy.getRank(); i++) {
1552  if (outputTy.isDynamicDim(i)) {
1553  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1554  }
1555  }
1556 
1557  DenseElementsAttr shiftElems;
1558  bool isShiftConstant = false;
1559  if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
1560  isShiftConstant = true;
1561 
1562  DenseElementsAttr multiplierElems;
1563  bool isMultiplierConstant = false;
1564  if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1565  isMultiplierConstant = true;
1566 
1567  llvm::SmallVector<int32_t> shiftValues;
1568  llvm::SmallVector<int32_t> multiplierValues;
1569  bool doubleRound;
1570 
1571  if (isMultiplierConstant && isShiftConstant) {
1572  // explicit cast is required here
1573  shiftValues = llvm::to_vector(llvm::map_range(
1574  shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1575  return static_cast<int32_t>(attr.getInt());
1576  }));
1577  multiplierValues = llvm::to_vector(
1578  llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1579  [](IntegerAttr attr) -> int32_t {
1580  return static_cast<int32_t>(attr.getInt());
1581  }));
1582 
1583  // A shift of more than 63 would shift out every bit, so just set the result to 0.
1584  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1585  if (shiftValues[i] > 63) {
1586  shiftValues[i] = 0;
1587  multiplierValues[i] = 0;
1588  }
1589  }
1590  // Double rounding only has an effect when the shift is greater than 31,
1591  // so check whether that is ever the case.
1592  doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1593  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1594  } else
1595  doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
1596 
1597  RoundingMode roundingMode =
1598  doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
1599 
1600  SmallVector<AffineMap> indexingMaps = {
1601  rewriter.getMultiDimIdentityMap(rank)};
1602  SmallVector<Value, 4> genericInputs = {input};
1603 
1604  // If we are rescaling per-channel then we need to store the multiplier
1605  // values in a buffer.
1606  Value multiplierConstant;
1607  int64_t multiplierArg = 0;
1608  setupLinalgGenericOpInputAndIndexingMap(
1609  rewriter, multiplierValues, genericInputs, indexingMaps,
1610  isMultiplierConstant, op, multiplierConstant, multiplierArg);
1611 
1612  // If we are rescaling per-channel then we need to store the shift
1613  // values in a buffer.
1614  Value shiftConstant;
1615  int64_t shiftArg = 0;
1616  setupLinalgGenericOpInputAndIndexingMap(
1617  rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1618  shiftConstant, shiftArg, true);
1619 
1620  // broadcastMap = affine_map<(d0, d1) -> ()>
1621  // It acts as a broadcast of the scalar value in linalg::GenericOp.
1622  AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
1623  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1624  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1625  // The inputZp and outputZp may be either constant or non-constant,
1626  // depending on whether dynamic extension is enabled.
1627  // - If the zp's are non-constant, add them as inputs to
1628  // linalg::GenericOp by:
1629  // 1. Pushing them into 'genericInputs'.
1630  // 2. Appending corresponding affine maps to 'indexingMaps'.
1631  // - If the zp's are constant, they are generated as arith.constant ops.
1632  int64_t iZpArg = 0;
1633  if (failed(maybeIZp)) {
1634  genericInputs.push_back(
1635  collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1636  indexingMaps.push_back(broadcastMap);
1637  iZpArg = indexingMaps.size() - 1;
1638  }
1639  int64_t oZpArg = 0;
1640  if (failed(maybeOZp)) {
1641  genericInputs.push_back(
1642  collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1643  indexingMaps.push_back(broadcastMap);
1644  oZpArg = indexingMaps.size() - 1;
1645  }
1646 
1647  // Indexing maps for output values.
1648  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1649 
1650  // Construct the empty output tensor for the linalg.generic op.
1651  Value emptyTensor = tensor::EmptyOp::create(
1652  rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
1653  ArrayRef<Value>({dynDims}));
1654 
1655  auto linalgOp = linalg::GenericOp::create(
1656  rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
1657  indexingMaps, getNParallelLoopsAttrs(rank),
1658  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1659  ValueRange blockArgs) {
1660  Value value = blockArgs[0];
1661  Type valueTy = value.getType();
1662 
1663  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1664  auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
1665  nestedLoc, blockArgs, iZpArg);
1666 
1667  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1668  auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
1669  nestedLoc, blockArgs, oZpArg, true);
1670 
1671  IntegerType outIntType =
1672  cast<IntegerType>(blockArgs.back().getType());
1673  unsigned outBitWidth = outIntType.getWidth();
1674  assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1675 
1676  Value multiplier = multiplierConstant ? multiplierConstant
1677  : blockArgs[multiplierArg];
1678  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1679 
1680  if (valueTy.isUnsignedInteger()) {
1681  value = UnrealizedConversionCastOp::create(
1682  nestedBuilder, nestedLoc,
1683  nestedBuilder.getIntegerType(
1684  valueTy.getIntOrFloatBitWidth()),
1685  value)
1686  .getResult(0);
1687  }
1688  if (valueTy.getIntOrFloatBitWidth() < 32) {
1689  if (op.getInputUnsigned()) {
1690  value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
1691  nestedBuilder.getI32Type(), value);
1692  } else {
1693  value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
1694  nestedBuilder.getI32Type(), value);
1695  }
1696  }
1697 
1698  value =
1699  arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
1700 
1701  value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
1702  nestedBuilder.getI32Type(), value,
1703  multiplier, shift, roundingMode);
1704 
1705  // Move to the new zero-point.
1706  value =
1707  arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
1708 
1709  // Saturate to the output size.
1710  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1711  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1712 
1713  // Unsigned integers have a different output range.
1714  if (op.getOutputUnsigned()) {
1715  intMin = 0;
1716  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1717  }
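// For example (illustrative): with an i8 output the clamp range is
// [-128, 127]; if output_unsigned is set it becomes [0, 255] instead.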
1718 
1719  auto intMinVal = arith::ConstantOp::create(
1720  nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
1721  auto intMaxVal = arith::ConstantOp::create(
1722  nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
1723 
1724  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1725  nestedBuilder, /*isUnsigned=*/false);
1726 
1727  if (outIntType.getWidth() < 32) {
1728  value = arith::TruncIOp::create(
1729  nestedBuilder, nestedLoc,
1730  rewriter.getIntegerType(outIntType.getWidth()), value);
1731  }
1732 
1733  if (outIntType.isUnsignedInteger()) {
1734  value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
1735  outIntType, value)
1736  .getResult(0);
1737  }
1738  linalg::YieldOp::create(nestedBuilder, loc, value);
1739  });
1740 
1741  rewriter.replaceOp(op, linalgOp->getResults());
1742  return success();
1743  }
1744 };
1745 
1746 // Handle the resize case where the input is a 1x1 image. This case
1747 // can entirely avoid the extract operations, which are much more
1748 // difficult to optimize away.
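// Illustrative sketch (assumed shapes): a bilinear i8 -> i32 resize of
// tensor<1x1x1x8xi8> is lowered by collapsing to tensor<1x8xi8>, then a
// linalg.generic that sign-extends each element to i32 and multiplies it by
// the y and x scale numerators (scale[0] and scale[2]), and finally a
// tensor.expand_shape back to tensor<1x1x1x8xi32>.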
1749 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1750 public:
1751  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1752 
1753  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1754  PatternRewriter &rewriter) const final {
1755  Location loc = op.getLoc();
1756  ImplicitLocOpBuilder builder(loc, rewriter);
1757  auto input = op.getInput();
1758  auto inputTy = cast<RankedTensorType>(input.getType());
1759  auto resultTy = cast<RankedTensorType>(op.getType());
1760  const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
1761 
1762  auto inputH = inputTy.getDimSize(1);
1763  auto inputW = inputTy.getDimSize(2);
1764  auto outputH = resultTy.getDimSize(1);
1765  auto outputW = resultTy.getDimSize(2);
1766 
1767  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1768  return rewriter.notifyMatchFailure(
1769  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1770 
1771  if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1772  op.getMode() != ResizeMode::BILINEAR)
1773  return rewriter.notifyMatchFailure(
1774  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1775 
1776  if (inputTy == resultTy) {
1777  rewriter.replaceOp(op, input);
1778  return success();
1779  }
1780 
1781  SmallVector<int64_t> scale;
1782  if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
1783  return failure();
1784  }
1785 
1786  // Collapse the unit width and height away.
1787  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1788  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1789  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1790  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1791  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1792 
1793  auto collapseTy =
1794  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1795  inputTy.getElementType());
1796  Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
1797  reassociationMap);
1798 
1799  // Get any dynamic shapes that appear in the input format.
1800  llvm::SmallVector<Value> outputDynSize;
1801  if (inputTy.isDynamicDim(0))
1802  outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1803  if (inputTy.isDynamicDim(3))
1804  outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1805 
1806  // Generate the elementwise operation for casting and scaling the input value.
1807  auto genericTy = collapseTy.clone(resultTy.getElementType());
1808  Value empty =
1809  tensor::EmptyOp::create(builder, genericTy.getShape(),
1810  resultTy.getElementType(), outputDynSize);
1811  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1812  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1813  utils::IteratorType::parallel);
1814 
1815  auto generic = linalg::GenericOp::create(
1816  builder, genericTy, ValueRange{collapse}, ValueRange{empty},
1817  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1818  [=](OpBuilder &b, Location loc, ValueRange args) {
1819  Value value = args[0];
1820  // This is the quantized case.
1821  if (inputTy.getElementType() != resultTy.getElementType()) {
1822  value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
1823  value);
1824 
1825  if (isBilinear && scale[0] != 0) {
1826  Value scaleY = arith::ConstantOp::create(
1827  b, loc, b.getI32IntegerAttr(scale[0]));
1828  value = arith::MulIOp::create(b, loc, value, scaleY);
1829  }
1830 
1831  if (isBilinear && scale[2] != 0) {
1832  Value scaleX = arith::ConstantOp::create(
1833  b, loc, b.getI32IntegerAttr(scale[2]));
1834  value = arith::MulIOp::create(b, loc, value, scaleX);
1835  }
1836  }
1837 
1838  linalg::YieldOp::create(b, loc, value);
1839  });
1840 
1841  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1842  op, resultTy, generic.getResults()[0], reassociationMap);
1843  return success();
1844  }
1845 };
1846 
1847 // TOSA resize with width or height of 1 may be broadcasted to a wider
1848 // dimension. This is done by materializing a new tosa.resize without
1849 // the broadcasting behavior, and an explicit broadcast afterwards.
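// Illustrative sketch (assumed shapes): resizing tensor<1x1x5x3xf32> to
// tensor<1x7x5x3xf32> first materializes a tosa.resize that produces
// tensor<1x1x5x3xf32> (no broadcast), collapses the unit height away, and
// then broadcasts along the height dimension with a linalg.generic whose
// input map omits d1.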
1850 class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1851 public:
1852  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1853 
1854  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1855  PatternRewriter &rewriter) const final {
1856  Location loc = op.getLoc();
1857  ImplicitLocOpBuilder builder(loc, rewriter);
1858  auto input = op.getInput();
1859  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1860  auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1861 
1862  if (!inputTy || !resultTy)
1863  return rewriter.notifyMatchFailure(op,
1864  "requires ranked input/output types");
1865 
1866  auto batch = inputTy.getDimSize(0);
1867  auto channels = inputTy.getDimSize(3);
1868  auto inputH = inputTy.getDimSize(1);
1869  auto inputW = inputTy.getDimSize(2);
1870  auto outputH = resultTy.getDimSize(1);
1871  auto outputW = resultTy.getDimSize(2);
1872 
1873  if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1874  return rewriter.notifyMatchFailure(
1875  op, "tosa.resize has no broadcasting behavior");
1876 
1877  // For any dimension that is broadcastable we generate a width of 1
1878  // on the output.
1879  llvm::SmallVector<int64_t> resizeShape;
1880  resizeShape.push_back(batch);
1881  resizeShape.push_back(inputH == 1 ? 1 : outputH);
1882  resizeShape.push_back(inputW == 1 ? 1 : outputW);
1883  resizeShape.push_back(channels);
1884 
1885  auto resizeTy = resultTy.clone(resizeShape);
1886  auto resize =
1887  tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
1888  op.getOffset(), op.getBorder(), op.getMode());
1889 
1890  // Collapse any unit result dims.
1891  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1892  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1893  reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1894  if (inputH != 1)
1895  reassociationMap.push_back({});
1896  reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1897  if (inputW != 1)
1898  reassociationMap.push_back({});
1899  reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1900 
1901  llvm::SmallVector<int64_t> collapseShape = {batch};
1902  if (inputH != 1)
1903  collapseShape.push_back(outputH);
1904  if (inputW != 1)
1905  collapseShape.push_back(outputW);
1906  collapseShape.push_back(channels);
1907 
1908  auto collapseTy = resultTy.clone(collapseShape);
1909  Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
1910  resize, reassociationMap);
1911 
1912  // Broadcast the collapsed shape to the output result.
1913  llvm::SmallVector<Value> outputDynSize;
1914  if (inputTy.isDynamicDim(0))
1915  outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1916  if (inputTy.isDynamicDim(3))
1917  outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1918 
1919  SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1920  utils::IteratorType::parallel);
1921  Value empty = tensor::EmptyOp::create(
1922  builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1923 
1924  SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1925  if (inputH != 1)
1926  inputExprs.push_back(rewriter.getAffineDimExpr(1));
1927  if (inputW != 1)
1928  inputExprs.push_back(rewriter.getAffineDimExpr(2));
1929  inputExprs.push_back(rewriter.getAffineDimExpr(3));
1930 
1931  auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1932  inputExprs, rewriter.getContext());
1933 
1934  auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1935  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1936  op, resultTy, ValueRange{collapse}, ValueRange{empty},
1937  ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1938  [=](OpBuilder &b, Location loc, ValueRange args) {
1939  Value value = args[0];
1940  linalg::YieldOp::create(b, loc, value);
1941  });
1942 
1943  return success();
1944  }
1945 };
1946 
1947 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1948 public:
1949  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1950 
1951  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1952  PatternRewriter &rewriter) const final {
1953  Location loc = op.getLoc();
1954  ImplicitLocOpBuilder b(loc, rewriter);
1955  auto input = op.getInput();
1956  auto inputTy = cast<ShapedType>(input.getType());
1957  auto resultTy = cast<ShapedType>(op.getType());
1958  auto resultETy = resultTy.getElementType();
1959 
1960  bool floatingPointMode = isa<FloatType>(resultETy);
1961  auto floatTy = resultETy;
1962 
1963  auto imageH = inputTy.getShape()[1];
1964  auto imageW = inputTy.getShape()[2];
1965 
1966  auto dynamicDimsOr =
1967  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1968  if (!dynamicDimsOr.has_value())
1969  return rewriter.notifyMatchFailure(
1970  op, "unable to get dynamic dimensions of tosa.resize");
1971 
1972  if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1973  op.getMode() != ResizeMode::BILINEAR)
1974  return rewriter.notifyMatchFailure(
1975  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1976 
1977  SmallVector<AffineMap, 2> affineMaps = {
1978  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1979  auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
1980  resultETy, *dynamicDimsOr);
1981  auto genericOp = linalg::GenericOp::create(
1982  b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1983  getNParallelLoopsAttrs(resultTy.getRank()));
1984  Value resize = genericOp.getResult(0);
1985 
1986  {
1987  OpBuilder::InsertionGuard regionGuard(b);
1988  b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1989  TypeRange({resultETy}), loc);
1990  Value batch = linalg::IndexOp::create(b, 0);
1991  Value y = linalg::IndexOp::create(b, 1);
1992  Value x = linalg::IndexOp::create(b, 2);
1993  Value channel = linalg::IndexOp::create(b, 3);
1994 
1995  Value zeroI32 =
1996  arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
1997  Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
1998  Value hMax =
1999  arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
2000  Value wMax =
2001  arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));
2002 
2003  Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
2004  Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);
2005 
2006  SmallVector<int64_t> scale, offset, border;
2007  if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
2008  !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
2009  !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
2010  return rewriter.notifyMatchFailure(
2011  op, "tosa.resize scale/offset/border should have compile time "
2012  "constant values.");
2013  }
2014 
2015  Value yScaleN, yScaleD, xScaleN, xScaleD;
2016  yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
2017  yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
2018  xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
2019  xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));
2020 
2021  Value yOffset, xOffset, yBorder, xBorder;
2022  yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
2023  xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
2024  yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
2025  xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
2026 
2027  // Compute the ix and dx values for both the X and Y dimensions.
2028  auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
2029  Value scaleN, Value scaleD, Value offset,
2030  int size, ImplicitLocOpBuilder &b) {
2031  if (size == 1) {
2032  index = zeroI32;
2033  delta = zeroFp;
2034  return;
2035  }
2036  // x = x * scale_d + offset;
2037  // ix = floor(x / scale_n)
2038  Value val = arith::MulIOp::create(b, in, scaleD);
2039  val = arith::AddIOp::create(b, val, offset);
2040  index = arith::FloorDivSIOp::create(b, val, scaleN);
2041 
2042  // rx = x % scale_n
2043  // dx = rx / scale_n
2044  Value r = arith::RemSIOp::create(b, val, scaleN);
2045  Value rFp = arith::SIToFPOp::create(b, floatTy, r);
2046  Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
2047  delta = arith::DivFOp::create(b, rFp, scaleNfp);
2048  };
2049 
2050  // Compute the ix and dx values for the X and Y dimensions - int case.
2051  auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
2052  Value scaleN, Value scaleD, Value offset,
2053  int size, ImplicitLocOpBuilder &b) {
2054  if (size == 1) {
2055  index = zeroI32;
2056  delta = zeroI32;
2057  return;
2058  }
2059  // x = x * scale_d + offset;
2060  // ix = floor(x / scale_n)
2061  // dx = x - ix * scale_n;
2062  Value val = arith::MulIOp::create(b, in, scaleD);
2063  val = arith::AddIOp::create(b, val, offset);
2064  index = arith::DivSIOp::create(b, val, scaleN);
2065  delta = arith::MulIOp::create(b, index, scaleN);
2066  delta = arith::SubIOp::create(b, val, delta);
2067  };
2068 
2069  Value ix, iy, dx, dy;
2070  if (floatingPointMode) {
2071  getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2072  getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2073  } else {
2074  getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2075  getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2076  }
2077 
2078  if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
2079  auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2080 
2081  auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
2082  Value max, int size,
2083  ImplicitLocOpBuilder &b) -> Value {
2084  if (size == 1) {
2085  return arith::ConstantIndexOp::create(b, 0);
2086  }
2087 
2088  Value pred;
2089  if (floatingPointMode) {
2090  auto h =
2091  arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
2092  pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
2093  } else {
2094  Value dvalDouble = arith::ShLIOp::create(b, dval, one);
2095  pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
2096  dvalDouble, scale);
2097  }
2098 
2099  auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
2100  val = arith::AddIOp::create(b, val, offset);
2101  val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
2102  return arith::IndexCastOp::create(b, b.getIndexType(), val);
2103  };
2104 
2105  iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
2106  ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
2107 
2108  Value result = tensor::ExtractOp::create(
2109  b, input, ValueRange{batch, iy, ix, channel});
2110 
2111  linalg::YieldOp::create(b, result);
2112  } else {
2113  // The mode here must be BILINEAR.
2114  assert(op.getMode() == ResizeMode::BILINEAR);
2115 
2116  auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2117 
2118  auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
2119  Value max, ImplicitLocOpBuilder &b) {
2120  val0 = in;
2121  val1 = arith::AddIOp::create(b, val0, oneVal);
2122  val0 =
2123  clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
2124  val1 =
2125  clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
2126  val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
2127  val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
2128  };
2129 
2130  // Linalg equivalent to the section below:
2131  // int16_t iy0 = apply_max(iy, 0);
2132  // int16_t iy1 = apply_min(iy + 1, IH - 1);
2133  // int16_t ix0 = apply_max(ix, 0);
2134  // int16_t ix1 = apply_min(ix + 1, IW - 1);
2135  Value x0, x1, y0, y1;
2136  getClampedIdxs(y0, y1, imageH, iy, hMax, b);
2137  getClampedIdxs(x0, x1, imageW, ix, wMax, b);
2138 
2139  Value y0x0 = tensor::ExtractOp::create(
2140  b, input, ValueRange{batch, y0, x0, channel});
2141  Value y0x1 = tensor::ExtractOp::create(
2142  b, input, ValueRange{batch, y0, x1, channel});
2143  Value y1x0 = tensor::ExtractOp::create(
2144  b, input, ValueRange{batch, y1, x0, channel});
2145  Value y1x1 = tensor::ExtractOp::create(
2146  b, input, ValueRange{batch, y1, x1, channel});
2147 
2148  if (floatingPointMode) {
2149  auto oneVal =
2150  arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
2151  auto interpolate = [&](Value val0, Value val1, Value delta,
2152  int inputSize,
2153  ImplicitLocOpBuilder &b) -> Value {
2154  if (inputSize == 1)
2155  return val0;
2156  Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
2157  Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
2158  Value mul1 = arith::MulFOp::create(b, val1, delta);
2159  return arith::AddFOp::create(b, mul0, mul1);
2160  };
2161 
2162  // Linalg equivalent to the section below:
2163  // topAcc = v00 * (unit_x - dx);
2164  // topAcc += v01 * dx;
2165  Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
2166 
2167  // Linalg equivalent to the section below:
2168  // bottomAcc = v10 * (unit_x - dx);
2169  // bottomAcc += v11 * dx;
2170  Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
2171 
2172  // Linalg equivalent to the section below:
2173  // result = topAcc * (unit_y - dy) + bottomAcc * dy
2174  Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
2175  linalg::YieldOp::create(b, result);
2176  } else {
2177  // Perform in quantized space.
2178  y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
2179  y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
2180  y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
2181  y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);
2182 
2183  const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
2184  if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
2185  dx = arith::ExtSIOp::create(b, resultETy, dx);
2186  dy = arith::ExtSIOp::create(b, resultETy, dy);
2187  }
2188 
2189  Value yScaleNExt = yScaleN;
2190  Value xScaleNExt = xScaleN;
2191 
2192  const int64_t scaleBitwidth =
2193  xScaleN.getType().getIntOrFloatBitWidth();
2194  if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2195  yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
2196  xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
2197  }
2198 
2199  auto interpolate = [](Value val0, Value val1, Value weight1,
2200  Value scale, int inputSize,
2201  ImplicitLocOpBuilder &b) -> Value {
2202  if (inputSize == 1)
2203  return arith::MulIOp::create(b, val0, scale);
2204  Value weight0 = arith::SubIOp::create(b, scale, weight1);
2205  Value mul0 = arith::MulIOp::create(b, val0, weight0);
2206  Value mul1 = arith::MulIOp::create(b, val1, weight1);
2207  return arith::AddIOp::create(b, mul0, mul1);
2208  };
2209 
2210  Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2211  Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2212  Value result =
2213  interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2214  linalg::YieldOp::create(b, result);
2215  }
2216  }
2217  }
2218 
2219  rewriter.replaceOp(op, resize);
2220  return success();
2221  }
2222 };
2223 
2224 // At the codegen level any identity operations should be removed. Any cases
2225 // where identity is load-bearing (e.g. cross device computation) should be
2226 // handled before lowering to codegen.
2227 template <typename SrcOp>
2228 class IdentityNConverter : public OpRewritePattern<SrcOp> {
2229 public:
2230  using OpRewritePattern<SrcOp>::OpRewritePattern;
2231 
2232  LogicalResult matchAndRewrite(SrcOp op,
2233  PatternRewriter &rewriter) const final {
2234  rewriter.replaceOp(op, op.getOperation()->getOperands());
2235  return success();
2236  }
2237 };
2238 
2239 template <typename SrcOp>
2240 class ReduceConverter : public OpRewritePattern<SrcOp> {
2241 public:
2242  using OpRewritePattern<SrcOp>::OpRewritePattern;
2243 
2244  LogicalResult matchAndRewrite(SrcOp reduceOp,
2245  PatternRewriter &rewriter) const final {
2246  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2247  }
2248 };
2249 
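// Lowers tosa.reverse to a linalg.generic that, for every output element,
// extracts the input element whose index along the reversed axis is
// (dimSize - 1 - i), leaving all other indices unchanged.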
2250 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
2251 public:
2252  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
2253 
2254  LogicalResult matchAndRewrite(tosa::ReverseOp op,
2255  PatternRewriter &rewriter) const final {
2256  auto loc = op.getLoc();
2257  Value input = op.getInput1();
2258  auto inputTy = cast<ShapedType>(input.getType());
2259  auto resultTy = cast<ShapedType>(op.getType());
2260  auto axis = op.getAxis();
2261 
2262  SmallVector<Value> dynDims;
2263  for (int i = 0; i < inputTy.getRank(); i++) {
2264  if (inputTy.isDynamicDim(i)) {
2265  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2266  }
2267  }
2268 
2269  Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
2270 
2271  // First create the empty output buffer.
2272  auto emptyTensor = tensor::EmptyOp::create(
2273  rewriter, loc, inputTy.getShape(),
2274  inputTy.getElementType(), ArrayRef<Value>({dynDims}))
2275  .getResult();
2276  SmallVector<AffineMap, 2> affineMaps = {
2277  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2278 
2279  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2280  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
2281  getNParallelLoopsAttrs(resultTy.getRank()),
2282  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2283  llvm::SmallVector<Value> indices;
2284  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
2285  Value index =
2286  linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
2287  if (i == axis) {
2288  auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
2289  auto sizeMinusOne =
2290  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
2291  index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
2292  index);
2293  }
2294 
2295  indices.push_back(index);
2296  }
2297 
2298  auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
2299  input, indices);
2300  linalg::YieldOp::create(nestedBuilder, op.getLoc(),
2301  extract.getResult());
2302  });
2303  return success();
2304  }
2305 };
2306 
2307 // This converter translates a tile operation to a reshape, broadcast, and
2308 // reshape. The first reshape minimally expands each tiled dimension to
2309 // include a preceding size-1 dim. This dim is then broadcast to the
2310 // appropriate multiple; see the illustrative sketch below.
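// Illustrative sketch (assumed shapes): tiling tensor<2x3xf32> with
// multiples [3, 1] uses the intermediate genericShape [3, 2, 1, 3]; the
// broadcast is a linalg.generic with input map
//   affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// and an identity output map, followed by a tosa.reshape to tensor<6x3xf32>.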
2311 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
2312  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
2313 
2314  LogicalResult
2315  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2316  ConversionPatternRewriter &rewriter) const override {
2317  auto loc = op.getLoc();
2318  auto input = op.getInput1();
2319  auto inputTy = cast<ShapedType>(input.getType());
2320  auto inputShape = inputTy.getShape();
2321  auto resultTy = cast<ShapedType>(op.getType());
2322  auto elementTy = inputTy.getElementType();
2323  int64_t rank = inputTy.getRank();
2324 
2325  SmallVector<int64_t> multiples;
2326  if (failed(op.getConstantMultiples(multiples)))
2327  return failure();
2328 
2329  // Broadcast the newly added dimensions to their appropriate multiple.
2330  SmallVector<int64_t, 2> genericShape;
2331  for (int i = 0; i < rank; i++) {
2332  int64_t dim = multiples[i];
2333  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
2334  genericShape.push_back(inputShape[i]);
2335  }
2336 
2337  SmallVector<Value> dynDims;
2338  for (int i = 0; i < inputTy.getRank(); i++) {
2339  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
2340  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2341  }
2342  }
2343 
2344  auto emptyTensor = tensor::EmptyOp::create(
2345  rewriter, op.getLoc(), genericShape, elementTy, dynDims);
2346 
2347  // We need to map the input shape to the non-broadcast dimensions.
2348  SmallVector<AffineExpr, 4> dimExprs;
2349  dimExprs.reserve(rank);
2350  for (unsigned i = 0; i < rank; ++i)
2351  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
2352 
2353  auto readAffineMap =
2354  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
2355  rewriter.getContext());
2356 
2357  SmallVector<AffineMap, 2> affineMaps = {
2358  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
2359 
2360  auto genericOp = linalg::GenericOp::create(
2361  rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
2362  ValueRange{emptyTensor}, affineMaps,
2363  getNParallelLoopsAttrs(genericShape.size()),
2364  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2365  linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
2366  });
2367 
2368  auto shapeValue = getTosaConstShape(
2369  rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
2370  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2371  op, resultTy, genericOp.getResult(0), shapeValue);
2372  return success();
2373  }
2374 };
2375 
2376 // Tosa argmax lowering represents the ArgMax op as a linalg.generic
2377 // op, producing two output buffers.
2378 //
2379 // The first output buffer contains the index of the found maximum value. It is
2380 // initialized to 0 and has the resulting integer type.
2381 //
2382 // The second output buffer contains the maximum value found. It is initialized
2383 // to the minimum representable value of the input element type. After being
2384 // populated by the generic op, this buffer is discarded as only the index is
2385 // requested.
2386 //
2387 // The generic op updates both the maximum value and index if the
2388 // current value exceeds the running max.
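// Illustrative sketch (assumed shapes): tosa.argmax over axis 1 of
// tensor<3x4xf32> builds a linalg.generic with indexing maps
//   input:              affine_map<(d0, d1) -> (d0, d1)>
//   index/max outputs:  affine_map<(d0, d1) -> (d0)>
// and iterator types [parallel, reduction]; only result 0 (the index tensor)
// replaces the original op.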
2389 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2390 public:
2391  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2392 
2393  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2394  PatternRewriter &rewriter) const final {
2395  auto loc = argmaxOp.getLoc();
2396  Value input = argmaxOp.getInput();
2397  auto inputTy = cast<ShapedType>(input.getType());
2398  auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2399  auto inElementTy = inputTy.getElementType();
2400  auto outElementTy = resultTy.getElementType();
2401  int axis = argmaxOp.getAxis();
2402  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2403 
2404  if (!isa<IntegerType>(outElementTy))
2405  return rewriter.notifyMatchFailure(
2406  argmaxOp,
2407  "tosa.arg_max to linalg.* requires integer-like result type");
2408 
2409  SmallVector<Value> dynDims;
2410  for (int i = 0; i < inputTy.getRank(); i++) {
2411  if (inputTy.isDynamicDim(i) && i != axis) {
2412  dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2413  }
2414  }
2415 
2416  // First fill the output buffer for the index.
2417  auto emptyTensorIdx =
2418  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2419  outElementTy, dynDims)
2420  .getResult();
2421  auto fillValueIdx = arith::ConstantOp::create(
2422  rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
2423  auto filledTensorIdx =
2424  linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
2425  ValueRange{emptyTensorIdx})
2426  .result();
2427 
2428  // Second fill the output buffer for the running max.
2429  auto emptyTensorMax =
2430  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
2431  dynDims)
2432  .getResult();
2433  auto fillValueMaxAttr =
2434  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2435 
2436  if (!fillValueMaxAttr)
2437  return rewriter.notifyMatchFailure(
2438  argmaxOp, "unsupported tosa.argmax element type");
2439 
2440  auto fillValueMax =
2441  arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
2442  auto filledTensorMax =
2443  linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
2444  ValueRange{emptyTensorMax})
2445  .result();
2446 
2447  // We need to reduce along the arg-max axis, with parallel operations along
2448  // the rest.
2449  SmallVector<utils::IteratorType> iteratorTypes;
2450  iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2451  iteratorTypes[axis] = utils::IteratorType::reduction;
2452 
2453  SmallVector<AffineExpr, 2> srcExprs;
2454  SmallVector<AffineExpr, 2> dstExprs;
2455  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2456  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2457  if (axis != i)
2458  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2459  }
2460 
2461  bool didEncounterError = false;
2462  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2463  rewriter.getContext());
2464  auto linalgOp = linalg::GenericOp::create(
2465  rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2466  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2467  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2468  ValueRange blockArgs) {
2469  auto newValue = blockArgs[0];
2470  auto oldIndex = blockArgs[1];
2471  auto oldValue = blockArgs[2];
2472 
2473  Value newIndex = arith::IndexCastOp::create(
2474  rewriter, nestedLoc, oldIndex.getType(),
2475  linalg::IndexOp::create(rewriter, loc, axis));
2476 
2477  Value predicate;
2478  if (isa<FloatType>(inElementTy)) {
2479  if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
2480  // Only update the index & max value for non-NaN values. If all
2481  // values are NaN, the initial index, which is 0, is returned.
2482  predicate = arith::CmpFOp::create(rewriter, nestedLoc,
2483  arith::CmpFPredicate::OGT,
2484  newValue, oldValue);
2485  } else {
2486  // Update max value if either of the following is true:
2487  // - new value is bigger
2488  // - cur max is not NaN and new value is NaN
2489  Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
2490  arith::CmpFPredicate::UGT,
2491  newValue, oldValue);
2492  Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
2493  arith::CmpFPredicate::ORD,
2494  oldValue, oldValue);
2495  predicate = arith::AndIOp::create(
2496  rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2497  }
2498  } else if (isa<IntegerType>(inElementTy)) {
2499  predicate = arith::CmpIOp::create(rewriter, nestedLoc,
2500  arith::CmpIPredicate::sgt,
2501  newValue, oldValue);
2502  } else {
2503  didEncounterError = true;
2504  return;
2505  }
2506 
2507  auto resultMax = arith::SelectOp::create(
2508  rewriter, nestedLoc, predicate, newValue, oldValue);
2509  auto resultIndex = arith::SelectOp::create(
2510  rewriter, nestedLoc, predicate, newIndex, oldIndex);
2511  linalg::YieldOp::create(nestedBuilder, nestedLoc,
2512  ValueRange({resultIndex, resultMax}));
2513  });
2514 
2515  if (didEncounterError)
2516  return rewriter.notifyMatchFailure(
2517  argmaxOp, "unsupported tosa.argmax element type");
2518 
2519  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2520  return success();
2521  }
2522 };
2523 
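// Lowers tosa.gather to a linalg.generic over the result shape: the index
// tensor is read through affine_map<(d0, d1, d2) -> (d0, d1)> and each
// output element is extracted from the values tensor at [batch, index, d2].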
2524 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2525 public:
2526  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2527  LogicalResult
2528  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2529  ConversionPatternRewriter &rewriter) const final {
2530  auto input = adaptor.getOperands()[0];
2531  auto indices = adaptor.getOperands()[1];
2532 
2533  auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2534  auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2535  if (!valuesTy || !resultTy)
2536  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2537 
2538  auto dynamicDims = inferDynamicDimsForGather(
2539  rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2540 
2541  auto resultElementTy = resultTy.getElementType();
2542 
2543  auto loc = op.getLoc();
2544  auto emptyTensor =
2545  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2546  resultElementTy, dynamicDims)
2547  .getResult();
2548 
2549  SmallVector<AffineMap, 2> affineMaps = {
2550  AffineMap::get(
2551  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2552  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2553  rewriter.getContext()),
2554  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2555 
2556  auto genericOp = linalg::GenericOp::create(
2557  rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2558  ValueRange{emptyTensor}, affineMaps,
2559  getNParallelLoopsAttrs(resultTy.getRank()),
2560  [&](OpBuilder &b, Location loc, ValueRange args) {
2561  auto indexValue = args[0];
2562  auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
2563  Value index1 = arith::IndexCastOp::create(
2564  rewriter, loc, rewriter.getIndexType(), indexValue);
2565  auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
2566  Value extract = tensor::ExtractOp::create(
2567  rewriter, loc, input, ValueRange{index0, index1, index2});
2568  linalg::YieldOp::create(rewriter, loc, extract);
2569  });
2570  rewriter.replaceOp(op, genericOp.getResult(0));
2571  return success();
2572  }
2573 
2574  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2575  Location loc,
2576  Value values,
2577  Value indices) {
2578  llvm::SmallVector<Value> results;
2579 
2580  auto addDynamicDimension = [&](Value source, int64_t dim) {
2581  auto sz = tensor::getMixedSize(builder, loc, source, dim);
2582  if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2583  results.push_back(dimValue);
2584  };
2585 
2586  addDynamicDimension(values, 0);
2587  addDynamicDimension(indices, 1);
2588  addDynamicDimension(values, 2);
2589  return results;
2590  }
2591 };
2592 
2593 // Lowers the TableOp to a series of gathers and numeric operations. This
2594 // includes interpolation between the high/low values. For the I8 variant,
2595 // this simplifies to a single gather operation; see the worked example below.
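// Worked example for the 16-bit path (assumed values, not from the source):
// an input value v = 0 maps to index = (0 + 32768) >> 7 = 256 and
// fraction = (0 + 32768) & 127 = 0, so the result is table[256] << 7. With
// fraction = 64 the result lies halfway between table[index] << 7 and
// table[index + 1] << 7.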
2596 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2597 public:
2598  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2599 
2600  LogicalResult matchAndRewrite(tosa::TableOp op,
2601  PatternRewriter &rewriter) const final {
2602  auto loc = op.getLoc();
2603  Value input = op.getInput1();
2604  Value table = op.getTable();
2605  auto inputTy = cast<ShapedType>(input.getType());
2606  auto tableTy = cast<ShapedType>(table.getType());
2607  auto resultTy = cast<ShapedType>(op.getType());
2608 
2609  auto inputElementTy = inputTy.getElementType();
2610  auto tableElementTy = tableTy.getElementType();
2611  auto resultElementTy = resultTy.getElementType();
2612 
2613  SmallVector<Value> dynDims;
2614  for (int i = 0; i < resultTy.getRank(); ++i) {
2615  if (inputTy.isDynamicDim(i)) {
2616  dynDims.push_back(
2617  tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
2618  }
2619  }
2620 
2621  auto emptyTensor =
2622  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2623  resultElementTy, dynDims)
2624  .getResult();
2625 
2626  SmallVector<AffineMap, 2> affineMaps = {
2627  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2628  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2629 
2630  auto genericOp = linalg::GenericOp::create(
2631  rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
2632  affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
2633  rewriter.replaceOp(op, genericOp.getResult(0));
2634 
2635  {
2636  OpBuilder::InsertionGuard regionGuard(rewriter);
2637  Block *block = rewriter.createBlock(
2638  &genericOp.getRegion(), genericOp.getRegion().end(),
2639  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2640 
2641  auto inputValue = block->getArgument(0);
2642  rewriter.setInsertionPointToStart(block);
2643  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2644  resultElementTy.isInteger(8)) {
2645  Value index = arith::IndexCastOp::create(
2646  rewriter, loc, rewriter.getIndexType(), inputValue);
2647  Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
2648  index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
2649  index, offset);
2650  Value extract =
2651  tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2652  linalg::YieldOp::create(rewriter, loc, extract);
2653  return success();
2654  }
2655 
2656  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2657  resultElementTy.isInteger(32)) {
2658  Value extend = arith::ExtSIOp::create(
2659  rewriter, loc, rewriter.getI32Type(), inputValue);
2660 
2661  auto offset = arith::ConstantOp::create(
2662  rewriter, loc, rewriter.getI32IntegerAttr(32768));
2663  auto seven = arith::ConstantOp::create(rewriter, loc,
2664  rewriter.getI32IntegerAttr(7));
2665  auto one = arith::ConstantOp::create(rewriter, loc,
2666  rewriter.getI32IntegerAttr(1));
2667  auto b1111111 = arith::ConstantOp::create(
2668  rewriter, loc, rewriter.getI32IntegerAttr(127));
2669 
2670  // Compute the index and fractional part from the input value:
2671  // value = value + 32768
2672  // index = value >> 7;
2673  // fraction = value & 0b1111111 (the low 7 bits)
2674  auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
2675  Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
2676  Value fraction =
2677  arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
2678 
2679  // Extract the base and next values from the table.
2680  // base = (int32_t) table[index];
2681  // next = (int32_t) table[index + 1];
2682  Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
2683 
2684  index = arith::IndexCastOp::create(rewriter, loc,
2685  rewriter.getIndexType(), index);
2686  indexPlusOne = arith::IndexCastOp::create(
2687  rewriter, loc, rewriter.getIndexType(), indexPlusOne);
2688 
2689  Value base =
2690  tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2691  Value next = tensor::ExtractOp::create(rewriter, loc, table,
2692  ValueRange{indexPlusOne});
2693 
2694  base =
2695  arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
2696  next =
2697  arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
2698 
2699  // Use the fractional part to interpolate between the input values:
2700  // result = (base << 7) + (next - base) * fraction
2701  Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
2702  Value diff = arith::SubIOp::create(rewriter, loc, next, base);
2703  Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
2704  Value result =
2705  arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
2706 
2707  linalg::YieldOp::create(rewriter, loc, result);
2708 
2709  return success();
2710  }
2711  }
2712 
2713  return rewriter.notifyMatchFailure(
2714  op, "unable to create body for tosa.table op");
2715  }
2716 };
2717 
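// Lowers tosa.rfft2d to a linalg.generic that reduces over the input
// positions (iy, ix) for every output position (n, oy, ox), accumulating
//   real += x[n, iy, ix] * cos(angle)
//   imag -= x[n, iy, ix] * sin(angle)
// with angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W). The
// output width is W / 2 + 1, matching the non-redundant half of a
// real-input FFT.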
2718 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2719  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2720 
2721  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2722 
2723  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2724  OpFoldResult ofr) {
2725  auto one = arith::ConstantIndexOp::create(builder, loc, 1);
2726  auto two = arith::ConstantIndexOp::create(builder, loc, 2);
2727 
2728  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2729  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2730  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2731  return getAsOpFoldResult(plusOne);
2732  }
2733 
2734  static RankedTensorType
2735  computeOutputShape(OpBuilder &builder, Location loc, Value input,
2736  llvm::SmallVectorImpl<Value> &dynamicSizes) {
2737  // Get [N, H, W]
2738  auto dims = tensor::getMixedSizes(builder, loc, input);
2739 
2740  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2741  // output tensors.
2742  dims[2] = halfPlusOne(builder, loc, dims[2]);
2743 
2744  llvm::SmallVector<int64_t, 3> staticSizes;
2745  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2746 
2747  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2748  return RankedTensorType::get(staticSizes, elementType);
2749  }
2750 
2751  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2752  RankedTensorType type,
2753  llvm::ArrayRef<Value> dynamicSizes) {
2754  auto emptyTensor =
2755  tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
2756  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2757  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
2758  auto filledTensor =
2759  linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
2760  ValueRange{emptyTensor})
2761  .result();
2762  return filledTensor;
2763  }
2764 
2765  static Value castIndexToFloat(OpBuilder &builder, Location loc,
2766  FloatType type, Value value) {
2767  auto integerVal = arith::IndexCastUIOp::create(
2768  builder, loc,
2769  type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2770  : builder.getI32Type(),
2771  value);
2772 
2773  return arith::UIToFPOp::create(builder, loc, type, integerVal);
2774  }
2775 
2776  static Value createLinalgIndex(OpBuilder &builder, Location loc,
2777  FloatType type, int64_t index) {
2778  auto indexVal = linalg::IndexOp::create(builder, loc, index);
2779  return castIndexToFloat(builder, loc, type, indexVal);
2780  }
2781 
2782  template <typename... Args>
2783  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2784  Args... args) {
2785  return {builder.getAffineDimExpr(args)...};
2786  }
2787 
2788  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2789  PatternRewriter &rewriter) const override {
2790  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2791  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2792  return rewriter.notifyMatchFailure(rfft2d,
2793  "only supports ranked tensors");
2794  }
2795 
2796  auto loc = rfft2d.getLoc();
2797  auto input = rfft2d.getInputReal();
2798  auto elementType =
2799  dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2800  if (!elementType)
2801  return rewriter.notifyMatchFailure(rfft2d,
2802  "only supports float element types");
2803 
2804  // Compute the output type and set of dynamic sizes
2805  llvm::SmallVector<Value> dynamicSizes;
2806  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2807 
2808  // Iterator types for the linalg.generic implementation
2809  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2810  utils::IteratorType::parallel, utils::IteratorType::parallel,
2811  utils::IteratorType::parallel, utils::IteratorType::reduction,
2812  utils::IteratorType::reduction};
2813 
2814  // Inputs/outputs to the linalg.generic implementation
2815  llvm::SmallVector<Value> genericOpInputs = {input};
2816  llvm::SmallVector<Value> genericOpOutputs = {
2817  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2818  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2819 
2820  // Indexing maps for input and output tensors
2821  auto indexingMaps = AffineMap::inferFromExprList(
2822  llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2823  affineDimsExpr(rewriter, 0, 1, 2),
2824  affineDimsExpr(rewriter, 0, 1, 2)},
2825  rewriter.getContext());
2826 
2827  // Width and height dimensions of the original input.
2828  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2829  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2830 
2831  // Constants and dimension sizes
2832  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2833  auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2834  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2835  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2836 
2837  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2838  Value valReal = args[0];
2839  Value sumReal = args[1];
2840  Value sumImag = args[2];
2841 
2842  // Indices for angle computation
2843  Value oy = linalg::IndexOp::create(builder, loc, 1);
2844  Value ox = linalg::IndexOp::create(builder, loc, 2);
2845  Value iy = linalg::IndexOp::create(builder, loc, 3);
2846  Value ix = linalg::IndexOp::create(builder, loc, 4);
2847 
2848  // Compute the angle without the integer parts of the components, since
2849  // sin/cos are periodic:
2850  // angle = 2 * pi() * (((iy * oy) % H) / H + ((ix * ox) % W) / W);
2851  auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2852  auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2853 
2854  auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2855  auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2856 
2857  auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2858  auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2859 
2860  auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2861  auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2862  auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2863  auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2864 
2865  // realComponent = valReal * cos(angle)
2866  // imagComponent = valReal * sin(angle)
2867  auto cosAngle = math::CosOp::create(builder, loc, angle);
2868  auto sinAngle = math::SinOp::create(builder, loc, angle);
2869  auto realComponent =
2870  arith::MulFOp::create(builder, loc, valReal, cosAngle);
2871  auto imagComponent =
2872  arith::MulFOp::create(builder, loc, valReal, sinAngle);
2873 
2874  // outReal = sumReal + realComponent
2875  // outImag = sumImag - imagComponent
2876  auto outReal =
2877  arith::AddFOp::create(builder, loc, sumReal, realComponent);
2878  auto outImag =
2879  arith::SubFOp::create(builder, loc, sumImag, imagComponent);
2880 
2881  linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
2882  };
2883 
2884  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2885  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2886  indexingMaps, iteratorTypes, buildBody);
2887 
2888  return success();
2889  }
2890 };
2891 
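// Editorial sketch (not part of this lowering): a naive scalar reference for
// what the linalg.generic built above computes for tosa.rfft2d. The helper
// name and the row-major buffer layout are assumptions, and it relies on
// <cmath>/<cstdint>, which this file does not otherwise include.
static void referenceRFFT2d(const float *input, float *outputReal,
                            float *outputImag, int64_t N, int64_t H,
                            int64_t W) {
  const int64_t outW = W / 2 + 1; // TOSA RFFT2D output width.
  const float twoPi = 6.283185307179586f;
  for (int64_t n = 0; n < N; ++n)
    for (int64_t oy = 0; oy < H; ++oy)
      for (int64_t ox = 0; ox < outW; ++ox) {
        float sumReal = 0.0f, sumImag = 0.0f;
        for (int64_t iy = 0; iy < H; ++iy)
          for (int64_t ix = 0; ix < W; ++ix) {
            // Same angle as the generated body; the modulo keeps only the
            // fractional part, which suffices because sin/cos are periodic.
            float angle = twoPi * (static_cast<float>((iy * oy) % H) / H +
                                   static_cast<float>((ix * ox) % W) / W);
            float val = input[(n * H + iy) * W + ix];
            sumReal += val * std::cos(angle);
            sumImag -= val * std::sin(angle);
          }
        outputReal[(n * H + oy) * outW + ox] = sumReal;
        outputImag[(n * H + oy) * outW + ox] = sumImag;
      }
}
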
2892 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2893  using OpRewritePattern::OpRewritePattern;
2894 
2895  LogicalResult matchAndRewrite(FFT2dOp fft2d,
2896  PatternRewriter &rewriter) const override {
2897  if (!llvm::all_of(fft2d->getOperandTypes(),
2898  RFFT2dConverter::isRankedTensor) ||
2899  !llvm::all_of(fft2d->getResultTypes(),
2900  RFFT2dConverter::isRankedTensor)) {
2901  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2902  }
2903 
2904  Location loc = fft2d.getLoc();
2905  Value input_real = fft2d.getInputReal();
2906  Value input_imag = fft2d.getInputImag();
2907  BoolAttr inverse = fft2d.getInverseAttr();
2908 
2909  auto real_el_ty = cast<FloatType>(
2910  cast<ShapedType>(input_real.getType()).getElementType());
2911  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2912  cast<ShapedType>(input_imag.getType()).getElementType());
2913 
2914  assert(real_el_ty == imag_el_ty);
2915 
2916  // Compute the output type and set of dynamic sizes
2917  SmallVector<Value> dynamicSizes;
2918 
2919  // Get [N, H, W]
2920  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2921 
2922  SmallVector<int64_t, 3> staticSizes;
2923  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2924 
2925  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2926 
2927  // Iterator types for the linalg.generic implementation
2928  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2929  utils::IteratorType::parallel, utils::IteratorType::parallel,
2930  utils::IteratorType::parallel, utils::IteratorType::reduction,
2931  utils::IteratorType::reduction};
2932 
2933  // Inputs/outputs to the linalg.generic implementation
2934  SmallVector<Value> genericOpInputs = {input_real, input_imag};
2935  SmallVector<Value> genericOpOutputs = {
2936  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2937  dynamicSizes),
2938  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2939  dynamicSizes)};
2940 
2941  // Indexing maps for input and output tensors
2942  auto indexingMaps = AffineMap::inferFromExprList(
2943  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2944  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2945  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2946  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2947  rewriter.getContext());
2948 
2949  // Height and width dimensions of the original input.
2950  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2951  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2952 
2953  // Constants and dimension sizes
2954  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2955  auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2956  Value constH =
2957  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2958  Value constW =
2959  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2960 
2961  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2962  Value valReal = args[0];
2963  Value valImag = args[1];
2964  Value sumReal = args[2];
2965  Value sumImag = args[3];
2966 
2967  // Indices for angle computation
2968  Value oy = linalg::IndexOp::create(builder, loc, 1);
2969  Value ox = linalg::IndexOp::create(builder, loc, 2);
2970  Value iy = linalg::IndexOp::create(builder, loc, 3);
2971  Value ix = linalg::IndexOp::create(builder, loc, 4);
2972 
2973  // float_t angle = sign_val * 2 * pi() *
2974  //     (((iy * oy) % H) / H + ((ix * ox) % W) / W);
2975  auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2976  auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2977 
2978  auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2979  auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2980 
2981  auto iyRemFloat =
2982  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2983  auto ixRemFloat =
2984  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2985 
2986  auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2987  auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2988 
2989  auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2990  auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2991 
2992  if (inverse.getValue()) {
2993  angle = arith::MulFOp::create(
2994  builder, loc, angle,
2995  arith::ConstantOp::create(rewriter, loc,
2996  rewriter.getFloatAttr(real_el_ty, -1.0)));
2997  }
2998 
2999  // realComponent = val_real * cos(a) + val_imag * sin(a);
3000  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
3001  auto cosAngle = math::CosOp::create(builder, loc, angle);
3002  auto sinAngle = math::SinOp::create(builder, loc, angle);
3003 
3004  auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
3005  auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
3006  auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
3007 
3008  auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
3009  auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
3010 
3011  auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
3012 
3013  // outReal = sumReal + realComponent
3014  // outImag = sumImag + imagComponent (the sign is already folded into imagComponent)
3015  auto outReal =
3016  arith::AddFOp::create(builder, loc, sumReal, realComponent);
3017  auto outImag =
3018  arith::AddFOp::create(builder, loc, sumImag, imagComponent);
3019 
3020  linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
3021  };
3022 
3023  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
3024  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
3025  indexingMaps, iteratorTypes, buildBody);
3026 
3027  return success();
3028  }
3029 };
3030 
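// Editorial note (derivation only, not additional lowering code): the
// per-element update in FFT2dConverter is a complex multiply-accumulate.
// With x = valReal + i*valImag and w = exp(-i*angle) = cos(angle) - i*sin(angle):
//
//   x * w = (valReal*cos(angle) + valImag*sin(angle))
//           + i*(valImag*cos(angle) - valReal*sin(angle))
//
// which is exactly realComponent and imagComponent above, so both accumulators
// are plain additions. When inverse=true the angle is negated, i.e. w becomes
// exp(+i*angle); since cos is even and sin is odd, only the sin terms change
// sign.
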
3031 } // namespace
3032 
3033 void mlir::tosa::populateTosaToLinalgConversionPatterns(
3034  const TypeConverter &converter, RewritePatternSet *patterns) {
3035 
3036  // We have multiple resize converters to handle degenerate cases.
3037  patterns->add<GenericResizeConverter>(patterns->getContext(),
3038  /*benefit=*/100);
3039  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
3040  /*benefit=*/200);
3041  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
3042  /*benefit=*/300);
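// Note: higher-benefit patterns are attempted first, so the two specialized
// resize lowerings above take precedence over GenericResizeConverter.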
3043 
3044  patterns->add<
3045  // clang-format off
3046  PointwiseConverter<tosa::AddOp>,
3047  PointwiseConverter<tosa::SubOp>,
3048  PointwiseConverter<tosa::MulOp>,
3049  PointwiseConverter<tosa::IntDivOp>,
3050  PointwiseConverter<tosa::NegateOp>,
3051  PointwiseConverter<tosa::PowOp>,
3052  PointwiseConverter<tosa::ReciprocalOp>,
3053  PointwiseConverter<tosa::RsqrtOp>,
3054  PointwiseConverter<tosa::LogOp>,
3055  PointwiseConverter<tosa::ExpOp>,
3056  PointwiseConverter<tosa::AbsOp>,
3057  PointwiseConverter<tosa::SinOp>,
3058  PointwiseConverter<tosa::CosOp>,
3059  PointwiseConverter<tosa::TanhOp>,
3060  PointwiseConverter<tosa::ErfOp>,
3061  PointwiseConverter<tosa::BitwiseAndOp>,
3062  PointwiseConverter<tosa::BitwiseOrOp>,
3063  PointwiseConverter<tosa::BitwiseNotOp>,
3064  PointwiseConverter<tosa::BitwiseXorOp>,
3065  PointwiseConverter<tosa::LogicalAndOp>,
3066  PointwiseConverter<tosa::LogicalNotOp>,
3067  PointwiseConverter<tosa::LogicalOrOp>,
3068  PointwiseConverter<tosa::LogicalXorOp>,
3069  PointwiseConverter<tosa::CastOp>,
3070  PointwiseConverter<tosa::LogicalLeftShiftOp>,
3071  PointwiseConverter<tosa::LogicalRightShiftOp>,
3072  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
3073  PointwiseConverter<tosa::ClzOp>,
3074  PointwiseConverter<tosa::SelectOp>,
3075  PointwiseConverter<tosa::GreaterOp>,
3076  PointwiseConverter<tosa::GreaterEqualOp>,
3077  PointwiseConverter<tosa::EqualOp>,
3078  PointwiseConverter<tosa::MaximumOp>,
3079  PointwiseConverter<tosa::MinimumOp>,
3080  PointwiseConverter<tosa::CeilOp>,
3081  PointwiseConverter<tosa::FloorOp>,
3082  PointwiseConverter<tosa::ClampOp>,
3083  PointwiseConverter<tosa::SigmoidOp>
3084  >(converter, patterns->getContext());
3085 
3086  patterns->add<
3087  IdentityNConverter<tosa::IdentityOp>,
3088  ReduceConverter<tosa::ReduceAllOp>,
3089  ReduceConverter<tosa::ReduceAnyOp>,
3090  ReduceConverter<tosa::ReduceMinOp>,
3091  ReduceConverter<tosa::ReduceMaxOp>,
3092  ReduceConverter<tosa::ReduceSumOp>,
3093  ReduceConverter<tosa::ReduceProductOp>,
3094  ArgMaxConverter,
3095  GatherConverter,
3096  RescaleConverter,
3097  ReverseConverter,
3098  RFFT2dConverter,
3099  FFT2dConverter,
3100  TableConverter,
3101  TileConverter>(patterns->getContext());
3102  // clang-format on
3103 }
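
// Editorial sketch (not part of this file): how a conversion driver typically
// consumes these patterns. The function name is hypothetical; the real driver
// is the tosa-to-linalg conversion pass.
static LogicalResult runTosaToLinalgExample(Operation *root) {
  MLIRContext *ctx = root->getContext();

  // An identity type conversion is sufficient for these patterns.
  TypeConverter converter;
  converter.addConversion([](Type type) { return type; });

  RewritePatternSet patterns(ctx);
  mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);

  // Every TOSA op must be converted; everything else is left untouched.
  ConversionTarget target(*ctx);
  target.addIllegalDialect<tosa::TosaDialect>();
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  return applyFullConversion(root, target, std::move(patterns));
}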