MLIR  21.0.0git
TosaToLinalg.cpp
Go to the documentation of this file.
1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 
34 #include <numeric>
35 #include <type_traits>
36 
37 using namespace mlir;
38 using namespace mlir::tosa;
39 
40 // Helper function to materialize the semantically correct compare and select
41 // operations given a binary operation with a specific NaN propagation mode.
42 //
43 // In the case of "PROPAGATE" semantics no compare and selection is required and
44 // this function does nothing.
45 //
46 // In the case of "IGNORE" semantics this function materializes a comparison of
47 // the current operands to the op which will return true for any NaN
48 // argument and then selects between the non-NaN operation argument and the
49 // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
50 // code:
51 //
52 // In the case that the op is operating on non floating point types we ignore
53 // the attribute completely, this is consistent with the TOSA spec which has
54 // the following wording: "This attribute is ignored by non floating-point
55 // types."
56 //
57 // binary<op>(lhs, rhs):
58 // result = op(lhs, rhs)
59 // if lhs == NaN return rhs
60 // if rhs == NaN return lhs
61 // return result
62 template <typename OpTy>
63 static Value
65  Value lhs, Value rhs, Value result) {
66  // NaN propagation has no meaning for non floating point types.
67  if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
68  return result;
69 
70  auto nanMode = op.getNanMode();
71  if (nanMode == "PROPAGATE")
72  return result;
73 
74  // Unordered comparison of NaN against itself will always return true.
75  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
76  op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
77  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
78  op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
79  Value rhsOrResult =
80  rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
81  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
82  rhsOrResult);
83 }
84 
85 template <typename T>
86 static arith::ConstantOp
87 createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
88  OpBuilder &rewriter) {
89  auto castedN = static_cast<T>(zp);
90  return rewriter.create<arith::ConstantOp>(
91  op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
92 }
93 
95  Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
96  ConversionPatternRewriter &rewriter) {
97  Location loc = op->getLoc();
98  auto elementTy =
99  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
100 
101  // tosa::AbsOp
102  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
103  return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
104 
105  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
106  auto zero = rewriter.create<arith::ConstantOp>(
107  loc, rewriter.getZeroAttr(elementTy));
108  auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
109  return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
110  }
111 
112  // tosa::AddOp
113  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
114  return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
115 
116  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
117  return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
118 
119  // tosa::SubOp
120  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
121  return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
122 
123  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
124  return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
125 
126  // tosa::IntDivOp
127  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
128  return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
129 
130  // tosa::ReciprocalOp
131  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
132  auto one =
133  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
134  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
135  }
136 
137  // tosa::MulOp
138  if (isa<tosa::MulOp>(op)) {
139  auto shiftVal = cast<tosa::MulOp>(op).getShift();
140  DenseElementsAttr shiftElem;
141  if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
142  (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
143  return nullptr;
144  }
145 
146  int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
147 
148  if (isa<FloatType>(elementTy)) {
149  if (shift != 0) {
150  (void)rewriter.notifyMatchFailure(op,
151  "Cannot have shift value for float");
152  return nullptr;
153  }
154  return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
155  }
156 
157  if (isa<IntegerType>(elementTy)) {
158  Value a = args[0];
159  Value b = args[1];
160 
161  if (shift > 0) {
162  auto shiftConst =
163  rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
164  if (!a.getType().isInteger(32))
165  a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
166 
167  if (!b.getType().isInteger(32))
168  b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
169 
170  auto result = rewriter.create<tosa::ApplyScaleOp>(
171  loc, rewriter.getI32Type(), a, b, shiftConst,
172  rewriter.getStringAttr("SINGLE_ROUND"));
173 
174  if (elementTy.isInteger(32))
175  return result;
176 
177  return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
178  }
179 
180  int aWidth = a.getType().getIntOrFloatBitWidth();
181  int bWidth = b.getType().getIntOrFloatBitWidth();
182  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
183 
184  if (aWidth < cWidth)
185  a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
186  if (bWidth < cWidth)
187  b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
188 
189  return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
190  }
191  }
192 
193  // tosa::NegateOp
194  if (isa<tosa::NegateOp>(op)) {
195  auto negate = cast<tosa::NegateOp>(op);
196 
197  FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
198  if (failed(maybeInZp)) {
199  (void)rewriter.notifyMatchFailure(
200  op, "input1 zero point cannot be statically determined");
201  return nullptr;
202  }
203 
204  FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
205  if (failed(maybeOutZp)) {
206  (void)rewriter.notifyMatchFailure(
207  op, "output zero point cannot be statically determined");
208  return nullptr;
209  }
210 
211  int64_t inZp = *maybeInZp;
212  int64_t outZp = *maybeOutZp;
213 
214  if (isa<FloatType>(elementTy))
215  return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
216 
217  if (isa<IntegerType>(elementTy)) {
218  if (!inZp && !outZp) {
219  auto constant = rewriter.create<arith::ConstantOp>(
220  loc, IntegerAttr::get(elementTy, 0));
221  return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
222  args[0]);
223  }
224 
225  // Compute the maximum value that can occur in the intermediate buffer.
226  const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
227  const int64_t zpAdd = inZp + outZp;
228  const int64_t maxValue =
229  APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
230  std::abs(zpAdd) + 1;
231 
232  // Convert that maximum value into the maximum bitwidth needed to
233  // represent it. We assume 48-bit numbers may be supported further in
234  // the pipeline.
235  int intermediateBitWidth = 64;
236  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
237  intermediateBitWidth = 16;
238  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
239  intermediateBitWidth = 32;
240  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
241  intermediateBitWidth = 48;
242  }
243 
244  Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
245  Value zpAddValue = rewriter.create<arith::ConstantOp>(
246  loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
247 
248  // The negation can be applied by doing:
249  // outputValue = inZp + outZp - inputValue
250  auto ext =
251  rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
252  auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
253 
254  // Clamp to the negation range.
255  Value min = rewriter.create<arith::ConstantIntOp>(
256  loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
257  intermediateType);
258  Value max = rewriter.create<arith::ConstantIntOp>(
259  loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
260  intermediateType);
261  auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
262 
263  // Truncate to the final value.
264  return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
265  }
266  }
267 
268  // tosa::BitwiseAndOp
269  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
270  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
271 
272  // tosa::BitwiseOrOp
273  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
274  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
275 
276  // tosa::BitwiseNotOp
277  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
278  auto allOnesAttr = rewriter.getIntegerAttr(
279  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
280  auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
281  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
282  }
283 
284  // tosa::BitwiseXOrOp
285  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
286  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
287 
288  // tosa::LogicalLeftShiftOp
289  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
290  return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
291 
292  // tosa::LogicalRightShiftOp
293  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
294  return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
295 
296  // tosa::ArithmeticRightShiftOp
297  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
298  auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
299  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
300  if (!round) {
301  return result;
302  }
303 
304  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
305  auto one =
306  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
307  auto zero =
308  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
309  auto i1one =
310  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
311 
312  // Checking that input2 != 0
313  auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
314  loc, arith::CmpIPredicate::sgt, args[1], zero);
315 
316  // Checking for the last bit of input1 to be 1
317  auto subtract =
318  rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
319  auto shifted =
320  rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
321  ->getResults();
322  auto truncated =
323  rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
324  auto isInputOdd =
325  rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
326 
327  auto shouldRound = rewriter.create<arith::AndIOp>(
328  loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
329  auto extended =
330  rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
331  return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
332  }
333 
334  // tosa::ClzOp
335  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
336  return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
337  }
338 
339  // tosa::LogicalAnd
340  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
341  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
342 
343  // tosa::LogicalNot
344  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
345  auto one = rewriter.create<arith::ConstantOp>(
346  loc, rewriter.getIntegerAttr(elementTy, 1));
347  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
348  }
349 
350  // tosa::LogicalOr
351  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
352  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
353 
354  // tosa::LogicalXor
355  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
356  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
357 
358  // tosa::PowOp
359  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
360  return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
361 
362  // tosa::RsqrtOp
363  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
364  return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
365 
366  // tosa::LogOp
367  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
368  return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
369 
370  // tosa::ExpOp
371  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
372  return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
373 
374  // tosa::SinOp
375  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
376  return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
377 
378  // tosa::CosOp
379  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
380  return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
381 
382  // tosa::TanhOp
383  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
384  return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
385 
386  // tosa::ErfOp
387  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
388  return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
389 
390  // tosa::GreaterOp
391  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
392  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
393  args[0], args[1]);
394 
395  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
396  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
397  args[0], args[1]);
398 
399  // tosa::GreaterEqualOp
400  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
401  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
402  args[0], args[1]);
403 
404  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
405  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
406  args[0], args[1]);
407 
408  // tosa::EqualOp
409  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
410  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
411  args[0], args[1]);
412 
413  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
414  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
415  args[0], args[1]);
416 
417  // tosa::SelectOp
418  if (isa<tosa::SelectOp>(op)) {
419  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
420  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
421  return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
422  }
423 
424  // tosa::MaximumOp
425  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
426  auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
427  return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
428  rewriter, args[0], args[1], max);
429  }
430 
431  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
432  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
433  }
434 
435  // tosa::MinimumOp
436  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
437  auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
438  return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
439  rewriter, args[0], args[1], min);
440  }
441 
442  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
443  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
444  }
445 
446  // tosa::CeilOp
447  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
448  return rewriter.create<math::CeilOp>(loc, resultTypes, args);
449 
450  // tosa::FloorOp
451  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
452  return rewriter.create<math::FloorOp>(loc, resultTypes, args);
453 
454  // tosa::ClampOp
455  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
456  bool losesInfo = false;
457  APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
458  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
459  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
460  APFloat::rmNearestTiesToEven, &losesInfo);
461  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
462  APFloat::rmNearestTiesToEven, &losesInfo);
463  auto min = rewriter.create<arith::ConstantOp>(
464  loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
465  auto max = rewriter.create<arith::ConstantOp>(
466  loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
467  auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
468 
469  auto clampOp = llvm::cast<tosa::ClampOp>(op);
470  const auto nanMode = clampOp.getNanMode();
471 
472  // NaN propagation has no meaning for non floating point types.
473  if (!isa<FloatType>(elementTy))
474  return result;
475 
476  // In the case of "PROPAGATE" semantics no compare and selection is
477  // required.
478  if (nanMode == "PROPAGATE")
479  return result;
480 
481  // In the case of "IGNORE" semantics materialize a comparison
482  // of the current operand to the reduction which will return true for a NaN
483  // argument and then selects between the initial reduction value and the
484  // calculated result based on whether the argument is NaN or not. In pseudo
485  // code:
486  //
487  // reduce<op>(x, init):
488  // result = op(init, x)
489  // return init if x == NaN else result
490 
491  // Unordered comparison of NaN against itself will always return true.
492  Value isNaN = rewriter.create<arith::CmpFOp>(
493  op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
494  // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
495  // is NaN.
496  return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
497  }
498 
499  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
500  auto intTy = cast<IntegerType>(elementTy);
501  int64_t min =
502  cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
503  int64_t max =
504  cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
505 
506  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
507  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
508  if (intTy.isUnsignedInteger()) {
509  minRepresentable = 0;
510  if (intTy.getIntOrFloatBitWidth() <= 63) {
511  maxRepresentable =
512  (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
513  .getZExtValue();
514  }
515  } else if (intTy.getIntOrFloatBitWidth() <= 64) {
516  // Ensure that min & max fit into signed n-bit constants.
517  minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
518  .getSExtValue();
519  maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
520  .getSExtValue();
521  }
522  // Ensure that the bounds are representable as n-bit signed/unsigned
523  // integers.
524  min = std::max(min, minRepresentable);
525  max = std::max(max, minRepresentable);
526  min = std::min(min, maxRepresentable);
527  max = std::min(max, maxRepresentable);
528 
529  auto minVal = rewriter.create<arith::ConstantIntOp>(
530  loc, min, intTy.getIntOrFloatBitWidth());
531  auto maxVal = rewriter.create<arith::ConstantIntOp>(
532  loc, max, intTy.getIntOrFloatBitWidth());
533  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
534  intTy.isUnsignedInteger());
535  }
536 
537  // tosa::SigmoidOp
538  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
539  auto one =
540  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
541  auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
542  auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
543  auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
544  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
545  }
546 
547  // tosa::CastOp
548  if (isa<tosa::CastOp>(op)) {
549  Type srcTy = elementTy;
550  Type dstTy = resultTypes.front();
551  if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
552  (void)rewriter.notifyMatchFailure(op, "unsupported type");
553  return nullptr;
554  }
555 
556  bool bitExtend =
558 
559  if (srcTy == dstTy)
560  return args.front();
561 
562  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
563  return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
564  std::nullopt);
565 
566  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
567  return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
568  std::nullopt);
569 
570  // 1-bit integers need to be treated as signless.
571  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
572  return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
573  std::nullopt);
574 
575  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
576  return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
577  std::nullopt);
578 
579  // Unsigned integers need an unrealized cast so that they can be passed
580  // to UIToFP.
581  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
582  auto unrealizedCast =
583  rewriter
584  .create<UnrealizedConversionCastOp>(
585  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
586  args[0])
587  .getResult(0);
588  return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
589  unrealizedCast);
590  }
591 
592  // All other si-to-fp conversions should be handled by SIToFP.
593  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
594  return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
595  std::nullopt);
596 
597  // Casting to boolean, floats need to only be checked as not-equal to zero.
598  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
599  Value zero = rewriter.create<arith::ConstantOp>(
600  loc, rewriter.getFloatAttr(srcTy, 0.0));
601  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
602  args.front(), zero);
603  }
604 
605  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
606  auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
607 
608  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
609  // Check whether neither int min nor int max can be represented in the
610  // input floating-point type due to too short exponent range.
611  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
612  APFloat::semanticsMaxExponent(fltSemantics)) {
613  // Use cmp + select to replace infinites by int min / int max. Other
614  // integral values can be represented in the integer space.
615  auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
616  auto posInf = rewriter.create<arith::ConstantOp>(
617  loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
618  APFloat::getInf(fltSemantics)));
619  auto negInf = rewriter.create<arith::ConstantOp>(
620  loc, rewriter.getFloatAttr(
621  getElementTypeOrSelf(srcTy),
622  APFloat::getInf(fltSemantics, /*Negative=*/true)));
623  auto overflow = rewriter.create<arith::CmpFOp>(
624  loc, arith::CmpFPredicate::UEQ, rounded, posInf);
625  auto underflow = rewriter.create<arith::CmpFOp>(
626  loc, arith::CmpFPredicate::UEQ, rounded, negInf);
627  auto intMin = rewriter.create<arith::ConstantOp>(
628  loc, rewriter.getIntegerAttr(
629  getElementTypeOrSelf(dstTy),
630  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
631  auto intMax = rewriter.create<arith::ConstantOp>(
632  loc, rewriter.getIntegerAttr(
633  getElementTypeOrSelf(dstTy),
634  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
635  auto maxClamped =
636  rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
637  return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
638  maxClamped);
639  }
640 
641  auto intMinFP = rewriter.create<arith::ConstantOp>(
642  loc, rewriter.getFloatAttr(
643  getElementTypeOrSelf(srcTy),
644  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
645  .getSExtValue()));
646 
647  // Check whether the mantissa has enough bits to represent int max.
648  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
649  dstTy.getIntOrFloatBitWidth() - 1) {
650  // Int min can also be represented since it is a power of two and thus
651  // consists of a single leading bit. Therefore we can clamp the input
652  // in the floating-point domain.
653 
654  auto intMaxFP = rewriter.create<arith::ConstantOp>(
655  loc, rewriter.getFloatAttr(
656  getElementTypeOrSelf(srcTy),
657  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
658  .getSExtValue()));
659 
660  Value clamped =
661  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
662  return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
663  }
664 
665  // Due to earlier check we know exponant range is big enough to represent
666  // int min. We can therefore rely on int max + 1 being representable as
667  // well because it's just int min with a positive sign. So clamp the min
668  // value and compare against that to select the max int value if needed.
669  auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
670  loc, rewriter.getFloatAttr(
671  getElementTypeOrSelf(srcTy),
672  static_cast<double>(
673  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
674  .getSExtValue()) +
675  1.0f));
676 
677  auto intMax = rewriter.create<arith::ConstantOp>(
678  loc, rewriter.getIntegerAttr(
679  getElementTypeOrSelf(dstTy),
680  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
681  auto minClampedFP =
682  rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
683  auto minClamped =
684  rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
685  auto overflow = rewriter.create<arith::CmpFOp>(
686  loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
687  return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
688  minClamped);
689  }
690 
691  // Casting to boolean, integers need to only be checked as not-equal to
692  // zero.
693  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
694  Value zero = rewriter.create<arith::ConstantIntOp>(
695  loc, 0, srcTy.getIntOrFloatBitWidth());
696  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
697  args.front(), zero);
698  }
699 
700  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
701  return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
702  std::nullopt);
703 
704  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
705  return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
706  }
707  }
708 
709  (void)rewriter.notifyMatchFailure(
710  op, "unhandled op for linalg body calculation for elementwise op");
711  return nullptr;
712 }
713 
715 
716 // Emit an 'arith.constant' op for the given index if it has not been created
717 // yet, or return an existing constant. This will prevent an excessive creation
718 // of redundant constants, easing readability of emitted code for unit tests.
720  IndexPool &indexPool, int64_t index) {
721  auto [it, inserted] = indexPool.try_emplace(index);
722  if (inserted)
723  it->second =
724  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
725  return it->second;
726 }
727 
729  IndexPool &indexPool, Value tensor, int64_t index) {
730  auto indexValue = createIndex(rewriter, loc, indexPool, index);
731  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
732 }
733 
735  IndexPool &indexPool, Value tensor,
736  int64_t index) {
737  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
738  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
739  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
740  if (shapedType.isDynamicDim(index))
741  return getTensorDim(rewriter, loc, indexPool, tensor, index);
742  return rewriter.getIndexAttr(shapedType.getDimSize(index));
743 }
744 
745 static bool operandsAndResultsRanked(Operation *operation) {
746  auto isRanked = [](Value value) {
747  return isa<RankedTensorType>(value.getType());
748  };
749  return llvm::all_of(operation->getOperands(), isRanked) &&
750  llvm::all_of(operation->getResults(), isRanked);
751 }
752 
753 // Compute the runtime dimension size for dimension 'dim' of the output by
754 // inspecting input 'operands', all of which are expected to have the same rank.
755 // This function returns a pair {targetSize, masterOperand}.
756 //
757 // The runtime size of the output dimension is returned either as a statically
758 // computed attribute or as a runtime SSA value.
759 //
760 // If the target size was inferred directly from one dominating operand, that
761 // operand is returned in 'masterOperand'. If the target size is inferred from
762 // multiple operands, 'masterOperand' is set to nullptr.
763 static std::pair<OpFoldResult, Value>
765  ValueRange operands, int64_t dim) {
766  // If any input operand contains a static size greater than 1 for this
767  // dimension, that is the target size. An occurrence of an additional static
768  // dimension greater than 1 with a different value is undefined behavior.
769  for (auto operand : operands) {
770  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
771  if (!ShapedType::isDynamic(size) && size > 1)
772  return {rewriter.getIndexAttr(size), operand};
773  }
774 
775  // Filter operands with dynamic dimension
776  auto operandsWithDynamicDim =
777  llvm::filter_to_vector(operands, [&](Value operand) {
778  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
779  });
780 
781  // If no operand has a dynamic dimension, it means all sizes were 1
782  if (operandsWithDynamicDim.empty())
783  return {rewriter.getIndexAttr(1), operands.front()};
784 
785  // Emit code that computes the runtime size for this dimension. If there is
786  // only one operand with a dynamic dimension, it is considered the master
787  // operand that determines the runtime size of the output dimension.
788  auto targetSize =
789  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
790  if (operandsWithDynamicDim.size() == 1)
791  return {targetSize, operandsWithDynamicDim[0]};
792 
793  // Calculate maximum size among all dynamic dimensions
794  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
795  auto nextSize =
796  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
797  targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
798  }
799  return {targetSize, nullptr};
800 }
801 
802 // Compute the runtime output size for all dimensions. This function returns
803 // a pair {targetShape, masterOperands}.
804 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
806  IndexPool &indexPool, ValueRange operands) {
807  assert(!operands.empty());
808  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
809  SmallVector<OpFoldResult> targetShape;
810  SmallVector<Value> masterOperands;
811  for (auto dim : llvm::seq<int64_t>(0, rank)) {
812  auto [targetSize, masterOperand] =
813  computeTargetSize(rewriter, loc, indexPool, operands, dim);
814  targetShape.push_back(targetSize);
815  masterOperands.push_back(masterOperand);
816  }
817  return {targetShape, masterOperands};
818 }
819 
821  IndexPool &indexPool, Value operand,
822  int64_t dim, OpFoldResult targetSize,
823  Value masterOperand) {
824  // Nothing to do if this is a static dimension
825  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
826  if (!rankedTensorType.isDynamicDim(dim))
827  return operand;
828 
829  // If the target size for this dimension was directly inferred by only taking
830  // this operand into account, there is no need to broadcast. This is an
831  // optimization that will prevent redundant control flow, and constitutes the
832  // main motivation for tracking "master operands".
833  if (operand == masterOperand)
834  return operand;
835 
836  // Affine maps for 'linalg.generic' op
837  auto rank = rankedTensorType.getRank();
838  SmallVector<AffineExpr> affineExprs;
839  for (auto index : llvm::seq<int64_t>(0, rank)) {
840  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
841  : rewriter.getAffineDimExpr(index);
842  affineExprs.push_back(affineExpr);
843  }
844  auto broadcastAffineMap =
845  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
846  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
847  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
848 
849  // Check if broadcast is necessary
850  auto one = createIndex(rewriter, loc, indexPool, 1);
851  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
852  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
853  loc, arith::CmpIPredicate::eq, runtimeSize, one);
854 
855  // Emit 'then' region of 'scf.if'
856  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
857  // It is not safe to cache constants across regions.
858  // New constants could potentially violate dominance requirements.
859  IndexPool localPool;
860 
861  // Emit 'tensor.empty' op
862  SmallVector<OpFoldResult> outputTensorShape;
863  for (auto index : llvm::seq<int64_t>(0, rank)) {
864  auto size = index == dim ? targetSize
865  : getOrFoldTensorDim(rewriter, loc, localPool,
866  operand, index);
867  outputTensorShape.push_back(size);
868  }
869  Value outputTensor = opBuilder.create<tensor::EmptyOp>(
870  loc, outputTensorShape, rankedTensorType.getElementType());
871 
872  // Emit 'linalg.generic' op
873  auto resultTensor =
874  opBuilder
875  .create<linalg::GenericOp>(
876  loc, outputTensor.getType(), operand, outputTensor, affineMaps,
878  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
879  // Emit 'linalg.yield' op
880  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
881  })
882  .getResult(0);
883 
884  // Cast to original operand type if necessary
885  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
886  loc, operand.getType(), resultTensor);
887 
888  // Emit 'scf.yield' op
889  opBuilder.create<scf::YieldOp>(loc, castResultTensor);
890  };
891 
892  // Emit 'else' region of 'scf.if'
893  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
894  opBuilder.create<scf::YieldOp>(loc, operand);
895  };
896 
897  // Emit 'scf.if' op
898  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
899  emitThenRegion, emitElseRegion);
900  return ifOp.getResult(0);
901 }
902 
904  IndexPool &indexPool, Value operand,
905  ArrayRef<OpFoldResult> targetShape,
906  ArrayRef<Value> masterOperands) {
907  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
908  assert((int64_t)targetShape.size() == rank);
909  assert((int64_t)masterOperands.size() == rank);
910  for (auto index : llvm::seq<int64_t>(0, rank))
911  operand =
912  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
913  targetShape[index], masterOperands[index]);
914  return operand;
915 }
916 
917 static SmallVector<Value>
919  IndexPool &indexPool, ValueRange operands,
920  ArrayRef<OpFoldResult> targetShape,
921  ArrayRef<Value> masterOperands) {
922  // No need to broadcast for unary operations
923  if (operands.size() == 1)
924  return operands;
925 
926  // Broadcast dynamic dimensions operand by operand
927  return llvm::map_to_vector(operands, [&](Value operand) {
928  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
929  targetShape, masterOperands);
930  });
931 }
932 
933 static LogicalResult
935  Operation *operation, ValueRange operands,
936  ArrayRef<OpFoldResult> targetShape,
937  const TypeConverter &converter) {
938  // Generate output tensor
939  auto resultType = cast_or_null<RankedTensorType>(
940  converter.convertType(operation->getResultTypes().front()));
941  if (!resultType) {
942  return rewriter.notifyMatchFailure(operation, "failed to convert type");
943  }
944  Value outputTensor = rewriter.create<tensor::EmptyOp>(
945  loc, targetShape, resultType.getElementType());
946 
947  // Create affine maps. Input affine maps broadcast static dimensions of size
948  // 1. The output affine map is an identity map.
949  //
950  auto rank = resultType.getRank();
951  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
952  auto shape = cast<ShapedType>(operand.getType()).getShape();
953  SmallVector<AffineExpr> affineExprs;
954  for (auto it : llvm::enumerate(shape)) {
955  // Prefer producting identity maps whenever possible (i.e. no broadcasting
956  // needed) because some transforms (like reshape folding)
957  // do not support affine constant exprs.
958  bool requiresBroadcast =
959  (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
960  auto affineExpr = requiresBroadcast
961  ? rewriter.getAffineConstantExpr(0)
962  : rewriter.getAffineDimExpr(it.index());
963  affineExprs.push_back(affineExpr);
964  }
965  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
966  });
967  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
968 
969  // Emit 'linalg.generic' op
970  bool encounteredError = false;
971  auto linalgOp = rewriter.create<linalg::GenericOp>(
972  loc, outputTensor.getType(), operands, outputTensor, affineMaps,
974  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
976  operation, blockArgs.take_front(operation->getNumOperands()),
977  {resultType.getElementType()}, rewriter);
978  if (!opResult) {
979  encounteredError = true;
980  return;
981  }
982  opBuilder.create<linalg::YieldOp>(loc, opResult);
983  });
984  if (encounteredError)
985  return rewriter.notifyMatchFailure(
986  operation, "unable to create linalg.generic body for elementwise op");
987 
988  // Cast 'linalg.generic' result into original result type if needed
989  auto castResult = rewriter.createOrFold<tensor::CastOp>(
990  loc, resultType, linalgOp->getResult(0));
991  rewriter.replaceOp(operation, castResult);
992  return success();
993 }
994 
996  ValueRange operands) {
997  // Shift cannot broadcast
998  if (isa<tosa::MulOp>(operation))
999  return operands.take_front(2);
1000  // Input1_zp and output_zp cannot broadcast
1001  if (isa<tosa::NegateOp>(operation))
1002  return operands.take_front(1);
1003  return operands;
1004 }
1005 
1006 static LogicalResult
1008  ConversionPatternRewriter &rewriter,
1009  const TypeConverter &converter) {
1010 
1011  // Collect op properties
1012  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1013  assert(operation->getNumOperands() >= 1 &&
1014  "elementwise op expects at least 1 operand");
1015  if (!operandsAndResultsRanked(operation))
1016  return rewriter.notifyMatchFailure(operation,
1017  "Unranked tensors not supported");
1018 
1019  // Lower operation
1020  IndexPool indexPool;
1021  auto loc = operation->getLoc();
1022  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1023  auto [targetShape, masterOperands] =
1024  computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
1025  auto broadcastOperands =
1026  broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
1027  targetShape, masterOperands);
1028  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
1029  targetShape, converter);
1030 }
1031 
1032 // Returns the constant initial value for a given reduction operation. The
1033 // attribute type varies depending on the element type required.
1034 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1035  PatternRewriter &rewriter) {
1036  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1037  return rewriter.getFloatAttr(elementTy, 0.0);
1038 
1039  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1040  return rewriter.getIntegerAttr(elementTy, 0);
1041 
1042  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1043  return rewriter.getFloatAttr(elementTy, 1.0);
1044 
1045  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1046  return rewriter.getIntegerAttr(elementTy, 1);
1047 
1048  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1049  return rewriter.getFloatAttr(
1050  elementTy, APFloat::getLargest(
1051  cast<FloatType>(elementTy).getFloatSemantics(), false));
1052 
1053  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1054  return rewriter.getIntegerAttr(
1055  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1056 
1057  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1058  return rewriter.getFloatAttr(
1059  elementTy, APFloat::getLargest(
1060  cast<FloatType>(elementTy).getFloatSemantics(), true));
1061 
1062  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1063  return rewriter.getIntegerAttr(
1064  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1065 
1066  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1067  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1068 
1069  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1070  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1071 
1072  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1073  return rewriter.getFloatAttr(
1074  elementTy, APFloat::getLargest(
1075  cast<FloatType>(elementTy).getFloatSemantics(), true));
1076 
1077  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1078  return rewriter.getIntegerAttr(
1079  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1080 
1081  return {};
1082 }
1083 
1084 // Creates the body calculation for a reduction. The operations vary depending
1085 // on the input type.
1087  ValueRange args,
1088  Type elementTy,
1089  PatternRewriter &rewriter) {
1090  Location loc = op->getLoc();
1091  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1092  return rewriter.create<arith::AddFOp>(loc, args);
1093  }
1094 
1095  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1096  return rewriter.create<arith::AddIOp>(loc, args);
1097  }
1098 
1099  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1100  return rewriter.create<arith::MulFOp>(loc, args);
1101  }
1102 
1103  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1104  return rewriter.create<arith::MulIOp>(loc, args);
1105  }
1106 
1107  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1108  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1109  }
1110 
1111  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1112  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1113  }
1114 
1115  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1116  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1117  }
1118 
1119  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1120  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1121  }
1122 
1123  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1124  return rewriter.create<arith::AndIOp>(loc, args);
1125 
1126  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1127  return rewriter.create<arith::OrIOp>(loc, args);
1128 
1129  return {};
1130 }
1131 
1132 // Performs the match and rewrite for reduction operations. This includes
1133 // declaring a correctly sized initial value, and the linalg.generic operation
1134 // that reduces across the specified axis.
1135 template <typename OpTy>
1136 static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1137  PatternRewriter &rewriter) {
1138  auto loc = op->getLoc();
1139  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1140  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1141  if (!inputTy || !resultTy)
1142  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1143 
1144  auto elementTy = resultTy.getElementType();
1145  Value input = op->getOperand(0);
1146 
1147  SmallVector<int64_t> reduceShape;
1148  SmallVector<Value> dynDims;
1149  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1150  if (axis != i) {
1151  reduceShape.push_back(inputTy.getDimSize(i));
1152  if (inputTy.isDynamicDim(i))
1153  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1154  }
1155  }
1156 
1157  SmallVector<Value> inputs, outputs;
1158  inputs.push_back(input);
1159 
1160  // First fill the output buffer with the init value.
1161  auto emptyTensor =
1162  rewriter
1163  .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1164  dynDims)
1165  .getResult();
1166 
1167  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1168  if (!fillValueAttr)
1169  return rewriter.notifyMatchFailure(
1170  op, "No initial value found for reduction operation");
1171 
1172  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1173  auto filledTensor = rewriter
1174  .create<linalg::FillOp>(loc, ValueRange{fillValue},
1175  ValueRange{emptyTensor})
1176  .result();
1177  outputs.push_back(filledTensor);
1178 
1179  bool isNanIgnoreMode = false;
1180  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1181  std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1182  // NaN propagation has no meaning for non floating point types.
1183  if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
1184  isNanIgnoreMode = true;
1185  // Because the TOSA spec requires the result be NaN iff all elements in
1186  // the reduction are NaN we can't simply perform a compare and select.
1187  // Additionally we have to keep track of whether we've seen any non-NaN
1188  // values and then do a final select based on this predicate.
1189  auto trueAttr = rewriter.getBoolAttr(true);
1190  auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
1191  auto emptyBoolTensor =
1192  rewriter
1193  .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
1194  dynDims)
1195  .getResult();
1196  auto allResultsNaNTensor =
1197  rewriter
1198  .create<linalg::FillOp>(loc, ValueRange{trueValue},
1199  ValueRange{emptyBoolTensor})
1200  .result();
1201  // Note that because the linalg::ReduceOp has two variadic arguments
1202  // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1203  // need to have the same number of inputs and outputs.
1204  //
1205  // The second input isn't actually used anywhere since the value used to
1206  // update the NaN flag is calculated inside the body of the reduction and
1207  // then used to update an out value.
1208  // In order to satisfy type constraints we just pass another copy of the
1209  // input here.
1210  inputs.push_back(input);
1211  outputs.push_back(allResultsNaNTensor);
1212  }
1213  }
1214 
1215  bool didEncounterError = false;
1216  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
1217  loc, inputs, outputs, axis,
1218  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1219  std::array<Value, 2> binaryArgs{
1220  blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1222  op, binaryArgs, elementTy, rewriter);
1223  if (result)
1224  didEncounterError = true;
1225 
1226  SmallVector<Value> resultsToYield;
1227  if (isNanIgnoreMode) {
1228  auto inputValue = blockArgs[0];
1229  auto initialValue = blockArgs[2];
1230  auto oldAllResultsNanFlagValue = blockArgs[3];
1231 
1232  // Unordered comparison of NaN against itself will always return true.
1233  Value isNaN = nestedBuilder.create<arith::CmpFOp>(
1234  op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
1235  // If we've encountered a NaN, take the non-NaN value.
1236  auto selectOp = nestedBuilder.create<arith::SelectOp>(
1237  op->getLoc(), isNaN, initialValue, result);
1238  // Update the flag which keeps track of whether we have seen a non-NaN
1239  // value.
1240  auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
1241  op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1242  resultsToYield.push_back(selectOp);
1243  resultsToYield.push_back(newAllResultsNanFlagValue);
1244  } else {
1245  resultsToYield.push_back(result);
1246  }
1247  nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
1248  });
1249 
1250  if (!didEncounterError)
1251  return rewriter.notifyMatchFailure(
1252  op, "unable to create linalg.generic body for reduce op");
1253 
1254  if (isNanIgnoreMode) {
1255  // Materialize a check to see whether we encountered any non-NaN values, if
1256  // we didn't we need to select a tensor of NaNs since the result will just
1257  // be the initial identity value propagated through all the compares and
1258  // selects inside the reduction.
1259 
1260  // Create a tensor full of NaNs.
1261  auto nanValueAttr = rewriter.getFloatAttr(
1262  elementTy,
1263  APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1264  auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
1265  auto emptyNanTensor =
1266  rewriter
1267  .create<tensor::EmptyOp>(loc, reduceShape,
1268  resultTy.getElementType(), dynDims)
1269  .getResult();
1270  auto nanFilledTensor =
1271  rewriter
1272  .create<linalg::FillOp>(loc, ValueRange{nanValue},
1273  ValueRange{emptyNanTensor})
1274  .result();
1275 
1276  // Create an empty tensor, non need to fill this since it will be
1277  // overwritten by the select.
1278  auto finalEmptyTensor =
1279  rewriter
1280  .create<tensor::EmptyOp>(loc, reduceShape,
1281  resultTy.getElementType(), dynDims)
1282  .getResult();
1283 
1284  // Do a selection between the tensors akin to:
1285  // result = NaN if "all results NaN" else result.
1286  SmallVector<Value> ins, outs;
1287  ins.push_back(linalgOp->getOpResult(1));
1288  ins.push_back(nanFilledTensor);
1289  ins.push_back(linalgOp->getResult(0));
1290  outs.push_back(finalEmptyTensor);
1291  auto linalgSelect =
1292  rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
1293  linalgOp = linalgSelect;
1294  }
1295 
1296  SmallVector<ReassociationExprs, 4> reassociationMap;
1297  uint64_t expandInputRank =
1298  cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
1299  reassociationMap.resize(expandInputRank);
1300 
1301  for (uint64_t i = 0; i < expandInputRank; i++) {
1302  int32_t dimToPush = i > axis ? i + 1 : i;
1303  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1304  }
1305 
1306  if (expandInputRank != 0) {
1307  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1308  reassociationMap[expandedDim].push_back(
1309  rewriter.getAffineDimExpr(expandedDim + 1));
1310  }
1311 
1312  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1313  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1314  // not have access to such information. This matters when handling dynamically
1315  // sized tensors.
1316  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1317  op, resultTy, linalgOp->getResults()[0], reassociationMap);
1318  return success();
1319 }
1320 
1321 namespace {
1322 
1323 template <typename SrcOp>
1324 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1325 public:
1328 
1329  LogicalResult
1330  matchAndRewrite(SrcOp op, OpAdaptor operands,
1331  ConversionPatternRewriter &rewriter) const final {
1333  op, operands.getOperands(), rewriter, *this->getTypeConverter());
1334  }
1335 };
1336 
1337 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1338 public:
1340 
1341  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1342  PatternRewriter &rewriter) const final {
1343  auto loc = op.getLoc();
1344  auto input = op.getInput();
1345  auto inputTy = cast<ShapedType>(op.getInput().getType());
1346  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1347  unsigned rank = inputTy.getRank();
1348 
1349  // This is an illegal configuration. terminate and log an error
1350  if (op.getRoundingMode() == "INEXACT_ROUND")
1351  return rewriter.notifyMatchFailure(
1352  op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1353  "currently supported");
1354  if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
1355  return rewriter.notifyMatchFailure(
1356  op, "tosa.rescale requires scale32 for double_round to be true");
1357 
1358  if (!isa<IntegerType>(inputTy.getElementType()))
1359  return rewriter.notifyMatchFailure(op, "only support integer type");
1360 
1361  SmallVector<Value> dynDims;
1362  for (int i = 0; i < outputTy.getRank(); i++) {
1363  if (outputTy.isDynamicDim(i)) {
1364  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1365  }
1366  }
1367 
1368  // The shift and multiplier values.
1369  DenseElementsAttr shiftElems;
1370  if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1371  return rewriter.notifyMatchFailure(
1372  op, "tosa.rescale requires constant shift input values");
1373 
1374  DenseElementsAttr multiplierElems;
1375  if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1376  return rewriter.notifyMatchFailure(
1377  op, "tosa.rescale requires constant multiplier input values");
1378 
1379  llvm::SmallVector<int8_t> shiftValues =
1380  llvm::to_vector(shiftElems.getValues<int8_t>());
1381  // explicit cast is required here
1382  llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1383  llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1384  [](IntegerAttr attr) -> int32_t {
1385  return static_cast<int32_t>(attr.getInt());
1386  }));
1387 
1388  // If we shift by more than the bitwidth, this just sets to 0.
1389  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1390  if (shiftValues[i] > 63) {
1391  shiftValues[i] = 0;
1392  multiplierValues[i] = 0;
1393  }
1394  }
1395 
1396  // Double round only occurs if shift is greater than 31, check that this
1397  // is ever true.
1398 
1399  bool doubleRound =
1400  op.getRoundingMode() == "DOUBLE_ROUND" &&
1401  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1402  StringAttr roundingMode = doubleRound
1403  ? rewriter.getStringAttr("DOUBLE_ROUND")
1404  : rewriter.getStringAttr("SINGLE_ROUND");
1405 
1406  SmallVector<AffineMap> indexingMaps = {
1407  rewriter.getMultiDimIdentityMap(rank)};
1408  SmallVector<Value, 4> genericInputs = {input};
1409 
1410  // If we are rescaling per-channel then we need to store the multiplier
1411  // values in a buffer.
1412  Value multiplierConstant;
1413  int64_t multiplierArg = 0;
1414  if (multiplierValues.size() == 1) {
1415  multiplierConstant = rewriter.create<arith::ConstantOp>(
1416  loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1417  } else {
1418  SmallVector<AffineExpr, 2> multiplierExprs{
1419  rewriter.getAffineDimExpr(rank - 1)};
1420  auto multiplierType =
1421  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1422  rewriter.getI32Type());
1423  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1424  loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1425 
1426  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1427  /*symbolCount=*/0, multiplierExprs,
1428  rewriter.getContext()));
1429 
1430  multiplierArg = indexingMaps.size() - 1;
1431  }
1432 
1433  // If we are rescaling per-channel then we need to store the shift
1434  // values in a buffer.
1435  Value shiftConstant;
1436  int64_t shiftArg = 0;
1437  if (shiftValues.size() == 1) {
1438  shiftConstant = rewriter.create<arith::ConstantOp>(
1439  loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1440  } else {
1441  SmallVector<AffineExpr, 2> shiftExprs = {
1442  rewriter.getAffineDimExpr(rank - 1)};
1443  auto shiftType =
1444  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1445  rewriter.getIntegerType(8));
1446  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1447  loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1448  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1449  /*symbolCount=*/0, shiftExprs,
1450  rewriter.getContext()));
1451  shiftArg = indexingMaps.size() - 1;
1452  }
1453 
1454  // Indexing maps for output values.
1455  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1456 
1457  // Construct the indexing maps needed for linalg.generic ops.
1458  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1459  loc, outputTy.getShape(), outputTy.getElementType(),
1460  ArrayRef<Value>({dynDims}));
1461 
1462  auto linalgOp = rewriter.create<linalg::GenericOp>(
1463  loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1464  getNParallelLoopsAttrs(rank),
1465  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1466  ValueRange blockArgs) {
1467  Value value = blockArgs[0];
1468  Type valueTy = value.getType();
1469 
1470  // For now we do all of our math in 64-bit. This is not optimal but
1471  // should be correct for now, consider computing correct bit depth
1472  // later.
1473  int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1474 
1475  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1476  if (failed(maybeIZp)) {
1477  (void)rewriter.notifyMatchFailure(
1478  op, "input zero point cannot be statically determined");
1479  return;
1480  }
1481 
1482  auto inputZp = createConstOpFromZpVal<int32_t>(
1483  op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
1484  nestedBuilder);
1485 
1486  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1487  if (failed(maybeOZp)) {
1488  (void)rewriter.notifyMatchFailure(
1489  op, "output zero point cannot be statically determined");
1490  return;
1491  };
1492 
1493  auto outputZp = createConstOpFromZpVal<int32_t>(
1494  op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
1495 
1496  Value multiplier = multiplierConstant ? multiplierConstant
1497  : blockArgs[multiplierArg];
1498  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1499 
1500  if (valueTy.getIntOrFloatBitWidth() < 32) {
1501  if (op.getInputUnsigned()) {
1502  value = nestedBuilder.create<arith::ExtUIOp>(
1503  nestedLoc, nestedBuilder.getI32Type(), value);
1504  } else {
1505  value = nestedBuilder.create<arith::ExtSIOp>(
1506  nestedLoc, nestedBuilder.getI32Type(), value);
1507  }
1508  }
1509 
1510  value =
1511  nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1512 
1513  value = nestedBuilder.create<tosa::ApplyScaleOp>(
1514  loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1515  roundingMode);
1516 
1517  // Move to the new zero-point.
1518  value =
1519  nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1520 
1521  // Saturate to the output size.
1522  IntegerType outIntType =
1523  cast<IntegerType>(blockArgs.back().getType());
1524  unsigned outBitWidth = outIntType.getWidth();
1525 
1526  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1527  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1528 
1529  // Unsigned integers have a difference output value.
1530  if (op.getOutputUnsigned()) {
1531  intMin = 0;
1532  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1533  }
1534 
1535  auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1536  loc, nestedBuilder.getI32IntegerAttr(intMin));
1537  auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1538  loc, nestedBuilder.getI32IntegerAttr(intMax));
1539 
1540  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1541  nestedBuilder, /*isUnsigned=*/false);
1542 
1543  if (outIntType.getWidth() < 32) {
1544  value = nestedBuilder.create<arith::TruncIOp>(
1545  nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1546  value);
1547  }
1548 
1549  nestedBuilder.create<linalg::YieldOp>(loc, value);
1550  });
1551 
1552  rewriter.replaceOp(op, linalgOp->getResults());
1553  return success();
1554  }
1555 };
1556 
1557 // Handle the resize case where the input is a 1x1 image. This case
1558 // can entirely avoiding having extract operations which target much
1559 // more difficult to optimize away.
1560 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1561 public:
1563 
1564  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1565  PatternRewriter &rewriter) const final {
1566  Location loc = op.getLoc();
1567  ImplicitLocOpBuilder builder(loc, rewriter);
1568  auto input = op.getInput();
1569  auto inputTy = cast<RankedTensorType>(input.getType());
1570  auto resultTy = cast<RankedTensorType>(op.getType());
1571  const bool isBilinear = op.getMode() == "BILINEAR";
1572 
1573  auto inputH = inputTy.getDimSize(1);
1574  auto inputW = inputTy.getDimSize(2);
1575  auto outputH = resultTy.getDimSize(1);
1576  auto outputW = resultTy.getDimSize(2);
1577 
1578  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1579  return rewriter.notifyMatchFailure(
1580  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1581 
1582  // TODO(suderman): These string values should be declared the TOSA dialect.
1583  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1584  return rewriter.notifyMatchFailure(
1585  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1586 
1587  if (inputTy == resultTy) {
1588  rewriter.replaceOp(op, input);
1589  return success();
1590  }
1591 
1592  SmallVector<int64_t> scale;
1593  if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
1594  return failure();
1595  }
1596 
1597  // Collapse the unit width and height away.
1598  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1599  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1600  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1601  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1602  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1603 
1604  auto collapseTy =
1605  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1606  inputTy.getElementType());
1607  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1608  reassociationMap);
1609 
1610  // Get any dynamic shapes that appear in the input format.
1611  llvm::SmallVector<Value> outputDynSize;
1612  if (inputTy.isDynamicDim(0))
1613  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1614  if (inputTy.isDynamicDim(3))
1615  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1616 
1617  // Generate the elementwise operation for casting scaling the input value.
1618  auto genericTy = collapseTy.clone(resultTy.getElementType());
1619  Value empty = builder.create<tensor::EmptyOp>(
1620  genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1621  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1622  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1623  utils::IteratorType::parallel);
1624 
1625  auto generic = builder.create<linalg::GenericOp>(
1626  genericTy, ValueRange{collapse}, ValueRange{empty},
1627  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1628  [=](OpBuilder &b, Location loc, ValueRange args) {
1629  Value value = args[0];
1630  // This is the quantized case.
1631  if (inputTy.getElementType() != resultTy.getElementType()) {
1632  value =
1633  b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1634 
1635  if (isBilinear && scale[0] != 0) {
1636  Value scaleY = b.create<arith::ConstantOp>(
1637  loc, b.getI32IntegerAttr(scale[0]));
1638  value = b.create<arith::MulIOp>(loc, value, scaleY);
1639  }
1640 
1641  if (isBilinear && scale[2] != 0) {
1642  Value scaleX = b.create<arith::ConstantOp>(
1643  loc, b.getI32IntegerAttr(scale[2]));
1644  value = b.create<arith::MulIOp>(loc, value, scaleX);
1645  }
1646  }
1647 
1648  b.create<linalg::YieldOp>(loc, value);
1649  });
1650 
1651  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1652  op, resultTy, generic.getResults()[0], reassociationMap);
1653  return success();
1654  }
1655 };
1656 
1657 // TOSA resize with width or height of 1 may be broadcasted to a wider
1658 // dimension. This is done by materializing a new tosa.resize without
1659 // the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // The pattern only applies when at least one spatial dimension broadcasts
    // from size 1 on the input to a larger size on the output.
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    // Re-emit the resize with the broadcast dims left at size 1. This smaller
    // resize has no broadcasting behavior and is lowered by other patterns.
    auto resizeTy = resultTy.clone(resizeShape);
    auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
                                                 op.getOffset(), op.getBorder(),
                                                 op.getMode());

    // Collapse any unit result dims. Build the reassociation groups so a
    // broadcast (unit) dim is merged into the previous group, while a
    // non-broadcast dim starts a new group of its own.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    // The collapsed shape keeps only the non-broadcast dimensions.
    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // The input indexing map omits the broadcast dims so each collapsed value
    // fans out across the corresponding expanded output dimension.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    // Identity output map + pass-through body: a pure broadcast copy.
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
1756 
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  // Generic lowering of tosa.resize: a linalg.generic with no inputs whose
  // body computes, per output element, the source coordinate (and for
  // BILINEAR the interpolation weights) and gathers from the input with
  // tensor.extract.
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    // f16/f32 results use the floating-point interpolation path; all other
    // element types use the integer (quantized) path.
    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    // Build the generic's body. The insertion guard restores the builder's
    // position once the region is populated.
    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      // Output layout is NHWC: (batch, y, x, channel).
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      // The resize parameters must be compile-time constants so they can be
      // materialized as arith constants below.
      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      // scale holds [y_numerator, y_denominator, x_numerator, x_denominator].
      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the ix and dx values for both the X and Y dimensions.
      // Delta is a float fraction in [0, 1) of the distance past the index.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        // A unit-sized dimension always maps to source index 0.
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Compute the ix and dx values for the X and Y dimensions - int case.
      // Delta stays in fixed point: the remainder against scale_n.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n;
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        // Round to the nearest source index (delta >= 0.5 rounds up) and
        // clamp the result into [0, size - 1].
        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            // Fixed-point rounding: 2 * delta >= scale_n  <=>  delta >= 0.5.
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        // Produce the two neighboring source indices (i, i + 1), clamped to
        // the valid range and cast to index type.
        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        // Linalg equivalent to the section below:
        // int16_t iy0 = apply_max(iy, 0);
        // int16_t iy1 = apply_min(iy + 1, IH - 1);
        // int16_t ix0 = apply_max(ix, 0);
        // int16_t ix1 = apply_min(ix + 1, IW - 1);
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        // Load the four neighboring samples used for interpolation.
        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          // Standard lerp: val0 * (1 - delta) + val1 * delta. Unit-sized
          // dimensions skip interpolation entirely.
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // Linalg equivalent to the section below:
          // topAcc = v00 * (unit_x - dx);
          // topAcc += v01 * dx;
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // Linalg equivalent to the section below:
          // bottomAcc = v10 * (unit_x - dx);
          // bottomAcc += v11 * dx;
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // Linalg equivalent to the section below:
          // result = topAcc * (unit_y - dy) + bottomAcc * dy
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform in quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          // Widen the deltas to the result element type if needed so the
          // multiplies below operate on matching widths.
          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          // Likewise widen the scale numerators when they are narrower than
          // the result element type.
          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          // Fixed-point lerp: val0 * (scale - weight1) + val1 * weight1.
          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
2029 
2030 // At the codegen level any identity operations should be removed. Any cases
2031 // where identity is load-bearing (e.g. cross device computation) should be
2032 // handled before lowering to codegen.
2033 template <typename SrcOp>
2034 class IdentityNConverter : public OpRewritePattern<SrcOp> {
2035 public:
2037 
2038  LogicalResult matchAndRewrite(SrcOp op,
2039  PatternRewriter &rewriter) const final {
2040  rewriter.replaceOp(op, op.getOperation()->getOperands());
2041  return success();
2042  }
2043 };
2044 
2045 template <typename SrcOp>
2046 class ReduceConverter : public OpRewritePattern<SrcOp> {
2047 public:
2049 
2050  LogicalResult matchAndRewrite(SrcOp reduceOp,
2051  PatternRewriter &rewriter) const final {
2052  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2053  }
2054 };
2055 
2056 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
2057 public:
2059 
2060  LogicalResult matchAndRewrite(tosa::ReverseOp op,
2061  PatternRewriter &rewriter) const final {
2062  auto loc = op.getLoc();
2063  Value input = op.getInput1();
2064  auto inputTy = cast<ShapedType>(input.getType());
2065  auto resultTy = cast<ShapedType>(op.getType());
2066  auto axis = op.getAxis();
2067 
2068  SmallVector<Value> dynDims;
2069  for (int i = 0; i < inputTy.getRank(); i++) {
2070  if (inputTy.isDynamicDim(i)) {
2071  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2072  }
2073  }
2074 
2075  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
2076 
2077  // First fill the output buffer with the init value.
2078  auto emptyTensor = rewriter
2079  .create<tensor::EmptyOp>(loc, inputTy.getShape(),
2080  inputTy.getElementType(),
2081  ArrayRef<Value>({dynDims}))
2082  .getResult();
2083  SmallVector<AffineMap, 2> affineMaps = {
2084  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2085 
2086  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2087  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
2088  getNParallelLoopsAttrs(resultTy.getRank()),
2089  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2090  llvm::SmallVector<Value> indices;
2091  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
2092  Value index =
2093  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
2094  if (i == axis) {
2095  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
2096  auto sizeMinusOne =
2097  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
2098  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
2099  index);
2100  }
2101 
2102  indices.push_back(index);
2103  }
2104 
2105  auto extract = nestedBuilder.create<tensor::ExtractOp>(
2106  nestedLoc, input, indices);
2107  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
2108  extract.getResult());
2109  });
2110  return success();
2111  }
2112 };
2113 
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcast to the appropriate
// multiple.
2118 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
2120 
2121  LogicalResult
2122  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2123  ConversionPatternRewriter &rewriter) const override {
2124  auto loc = op.getLoc();
2125  auto input = op.getInput1();
2126  auto inputTy = cast<ShapedType>(input.getType());
2127  auto inputShape = inputTy.getShape();
2128  auto resultTy = cast<ShapedType>(op.getType());
2129  auto elementTy = inputTy.getElementType();
2130  int64_t rank = inputTy.getRank();
2131 
2132  SmallVector<int64_t> multiples;
2133  if (failed(op.getConstantMultiples(multiples)))
2134  return failure();
2135 
2136  // Broadcast the newly added dimensions to their appropriate multiple.
2137  SmallVector<int64_t, 2> genericShape;
2138  for (int i = 0; i < rank; i++) {
2139  int64_t dim = multiples[i];
2140  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
2141  genericShape.push_back(inputShape[i]);
2142  }
2143 
2144  SmallVector<Value> dynDims;
2145  for (int i = 0; i < inputTy.getRank(); i++) {
2146  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
2147  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2148  }
2149  }
2150 
2151  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
2152  op.getLoc(), genericShape, elementTy, dynDims);
2153 
2154  // We needs to map the input shape to the non-broadcasted dimensions.
2155  SmallVector<AffineExpr, 4> dimExprs;
2156  dimExprs.reserve(rank);
2157  for (unsigned i = 0; i < rank; ++i)
2158  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
2159 
2160  auto readAffineMap =
2161  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
2162  rewriter.getContext());
2163 
2164  SmallVector<AffineMap, 2> affineMaps = {
2165  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
2166 
2167  auto genericOp = rewriter.create<linalg::GenericOp>(
2168  loc, RankedTensorType::get(genericShape, elementTy), input,
2169  ValueRange{emptyTensor}, affineMaps,
2170  getNParallelLoopsAttrs(genericShape.size()),
2171  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2172  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
2173  });
2174 
2175  auto shapeValue = getTosaConstShape(
2176  rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
2177  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2178  op, resultTy, genericOp.getResult(0), shapeValue);
2179  return success();
2180  }
2181 };
2182 
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Second output buffer: same shape as the index result but carrying the
    // input's element type for the running maximum.
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // The reduced axis is dropped from the result, so its dynamic size is
    // not needed for the output tensors.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    // Source map is the identity; both destinations drop the reduced axis.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Current position along the reduced axis, cast to the index
          // element type.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == "IGNORE") {
              // Only update index & max value for non NaN values. If all
              // values are NaNs, the initial index will be return which is 0.
              predicate = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
            } else {
              // Update max value if either of the following is true:
              // - new value is bigger
              // - cur max is not NaN and new value is NaN
              Value gt = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
              predicate = rewriter.create<arith::AndIOp>(
                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            // Cannot return failure() from inside the region builder; record
            // the error and bail out after the op is created.
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index result is requested; the max-value buffer is discarded.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
2328 
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    // The indices tensor is indexed by (batch, position) — the first two
    // output dims; the output uses the identity map.
    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          // result[batch, i, c] = values[batch, indices[batch, i], c].
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  // Collect tensor.dim values for the result's dynamic dimensions: dim 0
  // (batch) and dim 2 (channels) come from `values`, dim 1 from `indices`.
  // Static dimensions contribute nothing.
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      // getMixedSize yields a Value only for dynamic dims; static sizes come
      // back as attributes and are skipped.
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
2398 
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
2402 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2403 public:
2405 
2406  LogicalResult matchAndRewrite(tosa::TableOp op,
2407  PatternRewriter &rewriter) const final {
2408  auto loc = op.getLoc();
2409  Value input = op.getInput1();
2410  Value table = op.getTable();
2411  auto inputTy = cast<ShapedType>(input.getType());
2412  auto tableTy = cast<ShapedType>(table.getType());
2413  auto resultTy = cast<ShapedType>(op.getType());
2414 
2415  auto inputElementTy = inputTy.getElementType();
2416  auto tableElementTy = tableTy.getElementType();
2417  auto resultElementTy = resultTy.getElementType();
2418 
2419  SmallVector<Value> dynDims;
2420  for (int i = 0; i < resultTy.getRank(); ++i) {
2421  if (inputTy.isDynamicDim(i)) {
2422  dynDims.push_back(
2423  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2424  }
2425  }
2426 
2427  auto emptyTensor = rewriter
2428  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2429  resultElementTy, dynDims)
2430  .getResult();
2431 
2432  SmallVector<AffineMap, 2> affineMaps = {
2433  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2434  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2435 
2436  auto genericOp = rewriter.create<linalg::GenericOp>(
2437  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2438  getNParallelLoopsAttrs(resultTy.getRank()));
2439  rewriter.replaceOp(op, genericOp.getResult(0));
2440 
2441  {
2442  OpBuilder::InsertionGuard regionGuard(rewriter);
2443  Block *block = rewriter.createBlock(
2444  &genericOp.getRegion(), genericOp.getRegion().end(),
2445  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2446 
2447  auto inputValue = block->getArgument(0);
2448  rewriter.setInsertionPointToStart(block);
2449  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2450  resultElementTy.isInteger(8)) {
2451  Value index = rewriter.create<arith::IndexCastOp>(
2452  loc, rewriter.getIndexType(), inputValue);
2453  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2454  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2455  index, offset);
2456  Value extract =
2457  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2458  rewriter.create<linalg::YieldOp>(loc, extract);
2459  return success();
2460  }
2461 
2462  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2463  resultElementTy.isInteger(32)) {
2464  Value extend = rewriter.create<arith::ExtSIOp>(
2465  loc, rewriter.getI32Type(), inputValue);
2466 
2467  auto offset = rewriter.create<arith::ConstantOp>(
2468  loc, rewriter.getI32IntegerAttr(32768));
2469  auto seven = rewriter.create<arith::ConstantOp>(
2470  loc, rewriter.getI32IntegerAttr(7));
2471  auto one = rewriter.create<arith::ConstantOp>(
2472  loc, rewriter.getI32IntegerAttr(1));
2473  auto b1111111 = rewriter.create<arith::ConstantOp>(
2474  loc, rewriter.getI32IntegerAttr(127));
2475 
2476  // Compute the index and fractional part from the input value:
2477  // value = value + 32768
2478  // index = value >> 7;
2479  // fraction = 0x01111111 & value
2480  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2481  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2482  Value fraction =
2483  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2484 
2485  // Extract the base and next values from the table.
2486  // base = (int32_t) table[index];
2487  // next = (int32_t) table[index + 1];
2488  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2489 
2490  index = rewriter.create<arith::IndexCastOp>(
2491  loc, rewriter.getIndexType(), index);
2492  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2493  loc, rewriter.getIndexType(), indexPlusOne);
2494 
2495  Value base =
2496  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2497  Value next = rewriter.create<tensor::ExtractOp>(
2498  loc, table, ValueRange{indexPlusOne});
2499 
2500  base =
2501  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2502  next =
2503  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2504 
2505  // Use the fractional part to interpolate between the input values:
2506  // result = (base << 7) + (next - base) * fraction
2507  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2508  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2509  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2510  Value result =
2511  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2512 
2513  rewriter.create<linalg::YieldOp>(loc, result);
2514 
2515  return success();
2516  }
2517  }
2518 
2519  return rewriter.notifyMatchFailure(
2520  op, "unable to create body for tosa.table op");
2521  }
2522 };
2523 
2524 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2526 
2527  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2528 
2529  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2530  OpFoldResult ofr) {
2531  auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
2532  auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
2533 
2534  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2535  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2536  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2537  return getAsOpFoldResult(plusOne);
2538  }
2539 
2540  static RankedTensorType
2541  computeOutputShape(OpBuilder &builder, Location loc, Value input,
2542  llvm::SmallVectorImpl<Value> &dynamicSizes) {
2543  // Get [N, H, W]
2544  auto dims = tensor::getMixedSizes(builder, loc, input);
2545 
2546  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2547  // output tensors.
2548  dims[2] = halfPlusOne(builder, loc, dims[2]);
2549 
2550  llvm::SmallVector<int64_t, 3> staticSizes;
2551  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2552 
2553  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2554  return RankedTensorType::get(staticSizes, elementType);
2555  }
2556 
2557  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2558  RankedTensorType type,
2559  llvm::ArrayRef<Value> dynamicSizes) {
2560  auto emptyTensor =
2561  rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2562  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2563  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2564  auto filledTensor = rewriter
2565  .create<linalg::FillOp>(loc, ValueRange{fillValue},
2566  ValueRange{emptyTensor})
2567  .result();
2568  return filledTensor;
2569  }
2570 
2571  static Value castIndexToFloat(OpBuilder &builder, Location loc,
2572  FloatType type, Value value) {
2573  auto integerVal = builder.create<arith::IndexCastUIOp>(
2574  loc,
2575  type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2576  : builder.getI32Type(),
2577  value);
2578 
2579  return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2580  }
2581 
2582  static Value createLinalgIndex(OpBuilder &builder, Location loc,
2583  FloatType type, int64_t index) {
2584  auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2585  return castIndexToFloat(builder, loc, type, indexVal);
2586  }
2587 
2588  template <typename... Args>
2589  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2590  Args... args) {
2591  return {builder.getAffineDimExpr(args)...};
2592  }
2593 
2594  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2595  PatternRewriter &rewriter) const override {
2596  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2597  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2598  return rewriter.notifyMatchFailure(rfft2d,
2599  "only supports ranked tensors");
2600  }
2601 
2602  auto loc = rfft2d.getLoc();
2603  auto input = rfft2d.getInputReal();
2604  auto elementType =
2605  dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2606  if (!elementType)
2607  return rewriter.notifyMatchFailure(rfft2d,
2608  "only supports float element types");
2609 
2610  // Compute the output type and set of dynamic sizes
2611  llvm::SmallVector<Value> dynamicSizes;
2612  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2613 
2614  // Iterator types for the linalg.generic implementation
2616  utils::IteratorType::parallel, utils::IteratorType::parallel,
2617  utils::IteratorType::parallel, utils::IteratorType::reduction,
2618  utils::IteratorType::reduction};
2619 
2620  // Inputs/outputs to the linalg.generic implementation
2621  llvm::SmallVector<Value> genericOpInputs = {input};
2622  llvm::SmallVector<Value> genericOpOutputs = {
2623  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2624  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2625 
2626  // Indexing maps for input and output tensors
2627  auto indexingMaps = AffineMap::inferFromExprList(
2628  llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2629  affineDimsExpr(rewriter, 0, 1, 2),
2630  affineDimsExpr(rewriter, 0, 1, 2)},
2631  rewriter.getContext());
2632 
2633  // Width and height dimensions of the original input.
2634  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2635  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2636 
2637  // Constants and dimension sizes
2638  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2639  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2640  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2641  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2642 
2643  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2644  Value valReal = args[0];
2645  Value sumReal = args[1];
2646  Value sumImag = args[2];
2647 
2648  // Indices for angle computation
2649  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2650  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2651  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2652  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2653 
2654  // Calculating angle without integer parts of components as sin/cos are
2655  // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2656  // / W);
2657  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2658  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2659 
2660  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2661  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2662 
2663  auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2664  auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2665 
2666  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2667  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2668  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2669  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2670 
2671  // realComponent = valReal * cos(angle)
2672  // imagComponent = valReal * sin(angle)
2673  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2674  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2675  auto realComponent =
2676  builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2677  auto imagComponent =
2678  builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2679 
2680  // outReal = sumReal + realComponent
2681  // outImag = sumImag - imagComponent
2682  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2683  auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2684 
2685  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2686  };
2687 
2688  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2689  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2690  indexingMaps, iteratorTypes, buildBody);
2691 
2692  return success();
2693  }
2694 };
2695 
2696 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2698 
2699  LogicalResult matchAndRewrite(FFT2dOp fft2d,
2700  PatternRewriter &rewriter) const override {
2701  if (!llvm::all_of(fft2d->getOperandTypes(),
2702  RFFT2dConverter::isRankedTensor) ||
2703  !llvm::all_of(fft2d->getResultTypes(),
2704  RFFT2dConverter::isRankedTensor)) {
2705  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2706  }
2707 
2708  Location loc = fft2d.getLoc();
2709  Value input_real = fft2d.getInputReal();
2710  Value input_imag = fft2d.getInputImag();
2711  BoolAttr inverse = fft2d.getInverseAttr();
2712 
2713  auto real_el_ty = cast<FloatType>(
2714  cast<ShapedType>(input_real.getType()).getElementType());
2715  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2716  cast<ShapedType>(input_imag.getType()).getElementType());
2717 
2718  assert(real_el_ty == imag_el_ty);
2719 
2720  // Compute the output type and set of dynamic sizes
2721  SmallVector<Value> dynamicSizes;
2722 
2723  // Get [N, H, W]
2724  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2725 
2726  SmallVector<int64_t, 3> staticSizes;
2727  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2728 
2729  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2730 
2731  // Iterator types for the linalg.generic implementation
2732  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2733  utils::IteratorType::parallel, utils::IteratorType::parallel,
2734  utils::IteratorType::parallel, utils::IteratorType::reduction,
2735  utils::IteratorType::reduction};
2736 
2737  // Inputs/outputs to the linalg.generic implementation
2738  SmallVector<Value> genericOpInputs = {input_real, input_imag};
2739  SmallVector<Value> genericOpOutputs = {
2740  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2741  dynamicSizes),
2742  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2743  dynamicSizes)};
2744 
2745  // Indexing maps for input and output tensors
2746  auto indexingMaps = AffineMap::inferFromExprList(
2747  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2748  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2749  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2750  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2751  rewriter.getContext());
2752 
2753  // Width and height dimensions of the original input.
2754  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2755  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2756 
2757  // Constants and dimension sizes
2758  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2759  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2760  Value constH =
2761  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2762  Value constW =
2763  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2764 
2765  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2766  Value valReal = args[0];
2767  Value valImag = args[1];
2768  Value sumReal = args[2];
2769  Value sumImag = args[3];
2770 
2771  // Indices for angle computation
2772  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2773  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2774  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2775  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2776 
2777  // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2778  // ox) % W ) / W);
2779  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2780  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2781 
2782  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2783  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2784 
2785  auto iyRemFloat =
2786  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2787  auto ixRemFloat =
2788  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2789 
2790  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2791  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2792 
2793  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2794  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2795 
2796  if (inverse.getValue()) {
2797  angle = builder.create<arith::MulFOp>(
2798  loc, angle,
2799  rewriter.create<arith::ConstantOp>(
2800  loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2801  }
2802 
2803  // realComponent = val_real * cos(a) + val_imag * sin(a);
2804  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2805  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2806  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2807 
2808  auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2809  auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2810  auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2811 
2812  auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2813  auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2814 
2815  auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2816 
2817  // outReal = sumReal + realComponent
2818  // outImag = sumImag - imagComponent
2819  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2820  auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2821 
2822  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2823  };
2824 
2825  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2826  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2827  indexingMaps, iteratorTypes, buildBody);
2828 
2829  return success();
2830  }
2831 };
2832 
2833 } // namespace
2834 
2836  const TypeConverter &converter, RewritePatternSet *patterns) {
2837 
2838  // We have multiple resize coverters to handle degenerate cases.
2839  patterns->add<GenericResizeConverter>(patterns->getContext(),
2840  /*benefit=*/100);
2841  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2842  /*benefit=*/200);
2843  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2844  /*benefit=*/300);
2845 
2846  patterns->add<
2847  // clang-format off
2848  PointwiseConverter<tosa::AddOp>,
2849  PointwiseConverter<tosa::SubOp>,
2850  PointwiseConverter<tosa::MulOp>,
2851  PointwiseConverter<tosa::IntDivOp>,
2852  PointwiseConverter<tosa::NegateOp>,
2853  PointwiseConverter<tosa::PowOp>,
2854  PointwiseConverter<tosa::ReciprocalOp>,
2855  PointwiseConverter<tosa::RsqrtOp>,
2856  PointwiseConverter<tosa::LogOp>,
2857  PointwiseConverter<tosa::ExpOp>,
2858  PointwiseConverter<tosa::AbsOp>,
2859  PointwiseConverter<tosa::SinOp>,
2860  PointwiseConverter<tosa::CosOp>,
2861  PointwiseConverter<tosa::TanhOp>,
2862  PointwiseConverter<tosa::ErfOp>,
2863  PointwiseConverter<tosa::BitwiseAndOp>,
2864  PointwiseConverter<tosa::BitwiseOrOp>,
2865  PointwiseConverter<tosa::BitwiseNotOp>,
2866  PointwiseConverter<tosa::BitwiseXorOp>,
2867  PointwiseConverter<tosa::LogicalAndOp>,
2868  PointwiseConverter<tosa::LogicalNotOp>,
2869  PointwiseConverter<tosa::LogicalOrOp>,
2870  PointwiseConverter<tosa::LogicalXorOp>,
2871  PointwiseConverter<tosa::CastOp>,
2872  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2873  PointwiseConverter<tosa::LogicalRightShiftOp>,
2874  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2875  PointwiseConverter<tosa::ClzOp>,
2876  PointwiseConverter<tosa::SelectOp>,
2877  PointwiseConverter<tosa::GreaterOp>,
2878  PointwiseConverter<tosa::GreaterEqualOp>,
2879  PointwiseConverter<tosa::EqualOp>,
2880  PointwiseConverter<tosa::MaximumOp>,
2881  PointwiseConverter<tosa::MinimumOp>,
2882  PointwiseConverter<tosa::CeilOp>,
2883  PointwiseConverter<tosa::FloorOp>,
2884  PointwiseConverter<tosa::ClampOp>,
2885  PointwiseConverter<tosa::SigmoidOp>
2886  >(converter, patterns->getContext());
2887 
2888  patterns->add<
2889  IdentityNConverter<tosa::IdentityOp>,
2890  ReduceConverter<tosa::ReduceAllOp>,
2891  ReduceConverter<tosa::ReduceAnyOp>,
2892  ReduceConverter<tosa::ReduceMinOp>,
2893  ReduceConverter<tosa::ReduceMaxOp>,
2894  ReduceConverter<tosa::ReduceSumOp>,
2895  ReduceConverter<tosa::ReduceProductOp>,
2896  ArgMaxConverter,
2897  GatherConverter,
2898  RescaleConverter,
2899  ReverseConverter,
2900  RFFT2dConverter,
2901  FFT2dConverter,
2902  TableConverter,
2903  TileConverter>(patterns->getContext());
2904  // clang-format on
2905 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands)
static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result)
static arith::ConstantOp createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType, OpBuilder &rewriter)
static bool operandsAndResultsRanked(Operation *operation)
const float * table
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:368
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:60
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:323