MLIR  21.0.0git
TosaToLinalg.cpp
Go to the documentation of this file.
1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 
34 #include <numeric>
35 #include <type_traits>
36 
37 using namespace mlir;
38 using namespace mlir::tosa;
39 
40 // Helper function to materialize the semantically correct compare and select
41 // operations given a binary operation with a specific NaN propagation mode.
42 //
43 // In the case of "PROPAGATE" semantics no compare and selection is required and
44 // this function does nothing.
45 //
46 // In the case of "IGNORE" semantics this function materializes a comparison of
47 // the current operands to the op which will return true for any NaN
48 // argument and then selects between the non-NaN operation argument and the
49 // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
50 // code:
51 //
52 // In the case that the op is operating on non floating point types we ignore
53 // the attribute completely, this is consistent with the TOSA spec which has
54 // the following wording: "This attribute is ignored by non floating-point
55 // types."
56 //
57 // binary<op>(lhs, rhs):
58 // result = op(lhs, rhs)
59 // if lhs == NaN return rhs
60 // if rhs == NaN return lhs
61 // return result
62 template <typename OpTy>
63 static Value
65  Value lhs, Value rhs, Value result) {
66  // NaN propagation has no meaning for non floating point types.
67  if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
68  return result;
69 
70  auto nanMode = op.getNanMode();
71  if (nanMode == "PROPAGATE")
72  return result;
73 
74  // Unordered comparison of NaN against itself will always return true.
75  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
76  op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
77  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
78  op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
79  Value rhsOrResult =
80  rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
81  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
82  rhsOrResult);
83 }
84 
86  Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
87  ConversionPatternRewriter &rewriter) {
88  Location loc = op->getLoc();
89  auto elementTy =
90  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
91 
92  // tosa::AbsOp
93  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
94  return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
95 
96  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
97  auto zero = rewriter.create<arith::ConstantOp>(
98  loc, rewriter.getZeroAttr(elementTy));
99  auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
100  return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
101  }
102 
103  // tosa::AddOp
104  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
105  return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
106 
107  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
108  return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
109 
110  // tosa::SubOp
111  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
112  return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
113 
114  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
115  return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
116 
117  // tosa::IntDivOp
118  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
119  return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
120 
121  // tosa::ReciprocalOp
122  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
123  auto one =
124  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
125  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
126  }
127 
128  // tosa::MulOp
129  if (isa<tosa::MulOp>(op)) {
130  auto shiftVal = cast<tosa::MulOp>(op).getShift();
131  DenseElementsAttr shiftElem;
132  if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
133  (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
134  return nullptr;
135  }
136 
137  int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
138 
139  if (isa<FloatType>(elementTy)) {
140  if (shift != 0) {
141  (void)rewriter.notifyMatchFailure(op,
142  "Cannot have shift value for float");
143  return nullptr;
144  }
145  return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
146  }
147 
148  if (isa<IntegerType>(elementTy)) {
149  Value a = args[0];
150  Value b = args[1];
151 
152  if (shift > 0) {
153  auto shiftConst =
154  rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
155  if (!a.getType().isInteger(32))
156  a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
157 
158  if (!b.getType().isInteger(32))
159  b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
160 
161  auto result = rewriter.create<tosa::ApplyScaleOp>(
162  loc, rewriter.getI32Type(), a, b, shiftConst,
163  rewriter.getStringAttr("SINGLE_ROUND"));
164 
165  if (elementTy.isInteger(32))
166  return result;
167 
168  return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
169  }
170 
171  int aWidth = a.getType().getIntOrFloatBitWidth();
172  int bWidth = b.getType().getIntOrFloatBitWidth();
173  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
174 
175  if (aWidth < cWidth)
176  a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
177  if (bWidth < cWidth)
178  b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
179 
180  return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
181  }
182  }
183 
184  // tosa::NegateOp
185  if (isa<tosa::NegateOp>(op)) {
186  auto negate = cast<tosa::NegateOp>(op);
187 
188  FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
189  if (failed(maybeInZp)) {
190  (void)rewriter.notifyMatchFailure(
191  op, "input1 zero point cannot be statically determined");
192  return nullptr;
193  }
194 
195  FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
196  if (failed(maybeOutZp)) {
197  (void)rewriter.notifyMatchFailure(
198  op, "output zero point cannot be statically determined");
199  return nullptr;
200  }
201 
202  int64_t inZp = *maybeInZp;
203  int64_t outZp = *maybeOutZp;
204 
205  if (isa<FloatType>(elementTy))
206  return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
207 
208  if (isa<IntegerType>(elementTy)) {
209  if (!inZp && !outZp) {
210  auto constant = rewriter.create<arith::ConstantOp>(
211  loc, IntegerAttr::get(elementTy, 0));
212  return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
213  args[0]);
214  }
215 
216  // Compute the maximum value that can occur in the intermediate buffer.
217  const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
218  const int64_t zpAdd = inZp + outZp;
219  const int64_t maxValue =
220  APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
221  std::abs(zpAdd) + 1;
222 
223  // Convert that maximum value into the maximum bitwidth needed to
224  // represent it. We assume 48-bit numbers may be supported further in
225  // the pipeline.
226  int intermediateBitWidth = 64;
227  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
228  intermediateBitWidth = 16;
229  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
230  intermediateBitWidth = 32;
231  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
232  intermediateBitWidth = 48;
233  }
234 
235  Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
236  Value zpAddValue = rewriter.create<arith::ConstantOp>(
237  loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
238 
239  // The negation can be applied by doing:
240  // outputValue = inZp + outZp - inputValue
241  auto ext =
242  rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
243  auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
244 
245  // Clamp to the negation range.
246  Value min = rewriter.create<arith::ConstantIntOp>(
247  loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
248  intermediateType);
249  Value max = rewriter.create<arith::ConstantIntOp>(
250  loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
251  intermediateType);
252  auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
253 
254  // Truncate to the final value.
255  return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
256  }
257  }
258 
259  // tosa::BitwiseAndOp
260  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
261  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
262 
263  // tosa::BitwiseOrOp
264  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
265  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
266 
267  // tosa::BitwiseNotOp
268  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
269  auto allOnesAttr = rewriter.getIntegerAttr(
270  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
271  auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
272  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
273  }
274 
275  // tosa::BitwiseXOrOp
276  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
277  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
278 
279  // tosa::LogicalLeftShiftOp
280  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
281  return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
282 
283  // tosa::LogicalRightShiftOp
284  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
285  return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
286 
287  // tosa::ArithmeticRightShiftOp
288  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
289  auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
290  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
291  if (!round) {
292  return result;
293  }
294 
295  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
296  auto one =
297  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
298  auto zero =
299  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
300  auto i1one =
301  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
302 
303  // Checking that input2 != 0
304  auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
305  loc, arith::CmpIPredicate::sgt, args[1], zero);
306 
307  // Checking for the last bit of input1 to be 1
308  auto subtract =
309  rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
310  auto shifted =
311  rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
312  ->getResults();
313  auto truncated =
314  rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
315  auto isInputOdd =
316  rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
317 
318  auto shouldRound = rewriter.create<arith::AndIOp>(
319  loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
320  auto extended =
321  rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
322  return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
323  }
324 
325  // tosa::ClzOp
326  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
327  return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
328  }
329 
330  // tosa::LogicalAnd
331  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
332  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
333 
334  // tosa::LogicalNot
335  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
336  auto one = rewriter.create<arith::ConstantOp>(
337  loc, rewriter.getIntegerAttr(elementTy, 1));
338  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
339  }
340 
341  // tosa::LogicalOr
342  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
343  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
344 
345  // tosa::LogicalXor
346  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
347  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
348 
349  // tosa::PowOp
350  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
351  return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
352 
353  // tosa::RsqrtOp
354  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
355  return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
356 
357  // tosa::LogOp
358  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
359  return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
360 
361  // tosa::ExpOp
362  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
363  return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
364 
365  // tosa::SinOp
366  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
367  return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
368 
369  // tosa::CosOp
370  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
371  return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
372 
373  // tosa::TanhOp
374  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
375  return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
376 
377  // tosa::ErfOp
378  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
379  return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
380 
381  // tosa::GreaterOp
382  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
383  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
384  args[0], args[1]);
385 
386  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
387  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
388  args[0], args[1]);
389 
390  // tosa::GreaterEqualOp
391  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
392  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
393  args[0], args[1]);
394 
395  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
396  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
397  args[0], args[1]);
398 
399  // tosa::EqualOp
400  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
401  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
402  args[0], args[1]);
403 
404  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
405  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
406  args[0], args[1]);
407 
408  // tosa::SelectOp
409  if (isa<tosa::SelectOp>(op)) {
410  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
411  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
412  return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
413  }
414 
415  // tosa::MaximumOp
416  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
417  auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
418  return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
419  rewriter, args[0], args[1], max);
420  }
421 
422  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
423  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
424  }
425 
426  // tosa::MinimumOp
427  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
428  auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
429  return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
430  rewriter, args[0], args[1], min);
431  }
432 
433  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
434  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
435  }
436 
437  // tosa::CeilOp
438  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
439  return rewriter.create<math::CeilOp>(loc, resultTypes, args);
440 
441  // tosa::FloorOp
442  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
443  return rewriter.create<math::FloorOp>(loc, resultTypes, args);
444 
445  // tosa::ClampOp
446  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
447  bool losesInfo = false;
448  APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
449  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
450  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
451  APFloat::rmNearestTiesToEven, &losesInfo);
452  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
453  APFloat::rmNearestTiesToEven, &losesInfo);
454  auto min = rewriter.create<arith::ConstantOp>(
455  loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
456  auto max = rewriter.create<arith::ConstantOp>(
457  loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
458  auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
459 
460  auto clampOp = llvm::cast<tosa::ClampOp>(op);
461  const auto nanMode = clampOp.getNanMode();
462 
463  // NaN propagation has no meaning for non floating point types.
464  if (!isa<FloatType>(elementTy))
465  return result;
466 
467  // In the case of "PROPAGATE" semantics no compare and selection is
468  // required.
469  if (nanMode == "PROPAGATE")
470  return result;
471 
472  // In the case of "IGNORE" semantics materialize a comparison
473  // of the current operand to the reduction which will return true for a NaN
474  // argument and then selects between the initial reduction value and the
475  // calculated result based on whether the argument is NaN or not. In pseudo
476  // code:
477  //
478  // reduce<op>(x, init):
479  // result = op(init, x)
480  // return init if x == NaN else result
481 
482  // Unordered comparison of NaN against itself will always return true.
483  Value isNaN = rewriter.create<arith::CmpFOp>(
484  op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
485  // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
486  // is NaN.
487  return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
488  }
489 
490  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
491  auto intTy = cast<IntegerType>(elementTy);
492  int64_t min =
493  cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
494  int64_t max =
495  cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
496 
497  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
498  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
499  if (intTy.isUnsignedInteger()) {
500  minRepresentable = 0;
501  if (intTy.getIntOrFloatBitWidth() <= 63) {
502  maxRepresentable =
503  (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
504  .getZExtValue();
505  }
506  } else if (intTy.getIntOrFloatBitWidth() <= 64) {
507  // Ensure that min & max fit into signed n-bit constants.
508  minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
509  .getSExtValue();
510  maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
511  .getSExtValue();
512  }
513  // Ensure that the bounds are representable as n-bit signed/unsigned
514  // integers.
515  min = std::max(min, minRepresentable);
516  max = std::max(max, minRepresentable);
517  min = std::min(min, maxRepresentable);
518  max = std::min(max, maxRepresentable);
519 
520  auto minVal = rewriter.create<arith::ConstantIntOp>(
521  loc, min, intTy.getIntOrFloatBitWidth());
522  auto maxVal = rewriter.create<arith::ConstantIntOp>(
523  loc, max, intTy.getIntOrFloatBitWidth());
524  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
525  intTy.isUnsignedInteger());
526  }
527 
528  // tosa::SigmoidOp
529  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
530  auto one =
531  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
532  auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
533  auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
534  auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
535  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
536  }
537 
538  // tosa::CastOp
539  if (isa<tosa::CastOp>(op)) {
540  Type srcTy = elementTy;
541  Type dstTy = resultTypes.front();
542  if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
543  (void)rewriter.notifyMatchFailure(op, "unsupported type");
544  return nullptr;
545  }
546 
547  bool bitExtend =
549 
550  if (srcTy == dstTy)
551  return args.front();
552 
553  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
554  return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
555  std::nullopt);
556 
557  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
558  return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
559  std::nullopt);
560 
561  // 1-bit integers need to be treated as signless.
562  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
563  return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
564  std::nullopt);
565 
566  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
567  return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
568  std::nullopt);
569 
570  // Unsigned integers need an unrealized cast so that they can be passed
571  // to UIToFP.
572  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
573  auto unrealizedCast =
574  rewriter
575  .create<UnrealizedConversionCastOp>(
576  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
577  args[0])
578  .getResult(0);
579  return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
580  unrealizedCast);
581  }
582 
583  // All other si-to-fp conversions should be handled by SIToFP.
584  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
585  return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
586  std::nullopt);
587 
588  // Casting to boolean, floats need to only be checked as not-equal to zero.
589  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
590  Value zero = rewriter.create<arith::ConstantOp>(
591  loc, rewriter.getFloatAttr(srcTy, 0.0));
592  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
593  args.front(), zero);
594  }
595 
596  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
597  auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
598 
599  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
600  // Check whether neither int min nor int max can be represented in the
601  // input floating-point type due to too short exponent range.
602  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
603  APFloat::semanticsMaxExponent(fltSemantics)) {
604  // Use cmp + select to replace infinites by int min / int max. Other
605  // integral values can be represented in the integer space.
606  auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
607  auto posInf = rewriter.create<arith::ConstantOp>(
608  loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
609  APFloat::getInf(fltSemantics)));
610  auto negInf = rewriter.create<arith::ConstantOp>(
611  loc, rewriter.getFloatAttr(
612  getElementTypeOrSelf(srcTy),
613  APFloat::getInf(fltSemantics, /*Negative=*/true)));
614  auto overflow = rewriter.create<arith::CmpFOp>(
615  loc, arith::CmpFPredicate::UEQ, rounded, posInf);
616  auto underflow = rewriter.create<arith::CmpFOp>(
617  loc, arith::CmpFPredicate::UEQ, rounded, negInf);
618  auto intMin = rewriter.create<arith::ConstantOp>(
619  loc, rewriter.getIntegerAttr(
620  getElementTypeOrSelf(dstTy),
621  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
622  auto intMax = rewriter.create<arith::ConstantOp>(
623  loc, rewriter.getIntegerAttr(
624  getElementTypeOrSelf(dstTy),
625  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
626  auto maxClamped =
627  rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
628  return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
629  maxClamped);
630  }
631 
632  auto intMinFP = rewriter.create<arith::ConstantOp>(
633  loc, rewriter.getFloatAttr(
634  getElementTypeOrSelf(srcTy),
635  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
636  .getSExtValue()));
637 
638  // Check whether the mantissa has enough bits to represent int max.
639  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
640  dstTy.getIntOrFloatBitWidth() - 1) {
641  // Int min can also be represented since it is a power of two and thus
642  // consists of a single leading bit. Therefore we can clamp the input
643  // in the floating-point domain.
644 
645  auto intMaxFP = rewriter.create<arith::ConstantOp>(
646  loc, rewriter.getFloatAttr(
647  getElementTypeOrSelf(srcTy),
648  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
649  .getSExtValue()));
650 
651  Value clamped =
652  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
653  return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
654  }
655 
656  // Due to earlier check we know exponant range is big enough to represent
657  // int min. We can therefore rely on int max + 1 being representable as
658  // well because it's just int min with a positive sign. So clamp the min
659  // value and compare against that to select the max int value if needed.
660  auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
661  loc, rewriter.getFloatAttr(
662  getElementTypeOrSelf(srcTy),
663  static_cast<double>(
664  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
665  .getSExtValue()) +
666  1.0f));
667 
668  auto intMax = rewriter.create<arith::ConstantOp>(
669  loc, rewriter.getIntegerAttr(
670  getElementTypeOrSelf(dstTy),
671  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
672  auto minClampedFP =
673  rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
674  auto minClamped =
675  rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
676  auto overflow = rewriter.create<arith::CmpFOp>(
677  loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
678  return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
679  minClamped);
680  }
681 
682  // Casting to boolean, integers need to only be checked as not-equal to
683  // zero.
684  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
685  Value zero = rewriter.create<arith::ConstantIntOp>(
686  loc, 0, srcTy.getIntOrFloatBitWidth());
687  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
688  args.front(), zero);
689  }
690 
691  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
692  return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
693  std::nullopt);
694 
695  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
696  return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
697  }
698  }
699 
700  (void)rewriter.notifyMatchFailure(
701  op, "unhandled op for linalg body calculation for elementwise op");
702  return nullptr;
703 }
704 
706 
707 // Emit an 'arith.constant' op for the given index if it has not been created
708 // yet, or return an existing constant. This will prevent an excessive creation
709 // of redundant constants, easing readability of emitted code for unit tests.
711  IndexPool &indexPool, int64_t index) {
712  auto [it, inserted] = indexPool.try_emplace(index);
713  if (inserted)
714  it->second =
715  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
716  return it->second;
717 }
718 
720  IndexPool &indexPool, Value tensor, int64_t index) {
721  auto indexValue = createIndex(rewriter, loc, indexPool, index);
722  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
723 }
724 
726  IndexPool &indexPool, Value tensor,
727  int64_t index) {
728  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
729  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
730  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
731  if (shapedType.isDynamicDim(index))
732  return getTensorDim(rewriter, loc, indexPool, tensor, index);
733  return rewriter.getIndexAttr(shapedType.getDimSize(index));
734 }
735 
736 static bool operandsAndResultsRanked(Operation *operation) {
737  auto isRanked = [](Value value) {
738  return isa<RankedTensorType>(value.getType());
739  };
740  return llvm::all_of(operation->getOperands(), isRanked) &&
741  llvm::all_of(operation->getResults(), isRanked);
742 }
743 
744 // Compute the runtime dimension size for dimension 'dim' of the output by
745 // inspecting input 'operands', all of which are expected to have the same rank.
746 // This function returns a pair {targetSize, masterOperand}.
747 //
748 // The runtime size of the output dimension is returned either as a statically
749 // computed attribute or as a runtime SSA value.
750 //
751 // If the target size was inferred directly from one dominating operand, that
752 // operand is returned in 'masterOperand'. If the target size is inferred from
753 // multiple operands, 'masterOperand' is set to nullptr.
754 static std::pair<OpFoldResult, Value>
756  ValueRange operands, int64_t dim) {
757  // If any input operand contains a static size greater than 1 for this
758  // dimension, that is the target size. An occurrence of an additional static
759  // dimension greater than 1 with a different value is undefined behavior.
760  for (auto operand : operands) {
761  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
762  if (!ShapedType::isDynamic(size) && size > 1)
763  return {rewriter.getIndexAttr(size), operand};
764  }
765 
766  // Filter operands with dynamic dimension
767  auto operandsWithDynamicDim =
768  llvm::filter_to_vector(operands, [&](Value operand) {
769  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
770  });
771 
772  // If no operand has a dynamic dimension, it means all sizes were 1
773  if (operandsWithDynamicDim.empty())
774  return {rewriter.getIndexAttr(1), operands.front()};
775 
776  // Emit code that computes the runtime size for this dimension. If there is
777  // only one operand with a dynamic dimension, it is considered the master
778  // operand that determines the runtime size of the output dimension.
779  auto targetSize =
780  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
781  if (operandsWithDynamicDim.size() == 1)
782  return {targetSize, operandsWithDynamicDim[0]};
783 
784  // Calculate maximum size among all dynamic dimensions
785  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
786  auto nextSize =
787  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
788  targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
789  }
790  return {targetSize, nullptr};
791 }
792 
793 // Compute the runtime output size for all dimensions. This function returns
794 // a pair {targetShape, masterOperands}.
795 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
797  IndexPool &indexPool, ValueRange operands) {
798  assert(!operands.empty());
799  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
800  SmallVector<OpFoldResult> targetShape;
801  SmallVector<Value> masterOperands;
802  for (auto dim : llvm::seq<int64_t>(0, rank)) {
803  auto [targetSize, masterOperand] =
804  computeTargetSize(rewriter, loc, indexPool, operands, dim);
805  targetShape.push_back(targetSize);
806  masterOperands.push_back(masterOperand);
807  }
808  return {targetShape, masterOperands};
809 }
810 
812  IndexPool &indexPool, Value operand,
813  int64_t dim, OpFoldResult targetSize,
814  Value masterOperand) {
815  // Nothing to do if this is a static dimension
816  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
817  if (!rankedTensorType.isDynamicDim(dim))
818  return operand;
819 
820  // If the target size for this dimension was directly inferred by only taking
821  // this operand into account, there is no need to broadcast. This is an
822  // optimization that will prevent redundant control flow, and constitutes the
823  // main motivation for tracking "master operands".
824  if (operand == masterOperand)
825  return operand;
826 
827  // Affine maps for 'linalg.generic' op
828  auto rank = rankedTensorType.getRank();
829  SmallVector<AffineExpr> affineExprs;
830  for (auto index : llvm::seq<int64_t>(0, rank)) {
831  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
832  : rewriter.getAffineDimExpr(index);
833  affineExprs.push_back(affineExpr);
834  }
835  auto broadcastAffineMap =
836  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
837  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
838  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
839 
840  // Check if broadcast is necessary
841  auto one = createIndex(rewriter, loc, indexPool, 1);
842  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
843  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
844  loc, arith::CmpIPredicate::eq, runtimeSize, one);
845 
846  // Emit 'then' region of 'scf.if'
847  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
848  // It is not safe to cache constants across regions.
849  // New constants could potentially violate dominance requirements.
850  IndexPool localPool;
851 
852  // Emit 'tensor.empty' op
853  SmallVector<OpFoldResult> outputTensorShape;
854  for (auto index : llvm::seq<int64_t>(0, rank)) {
855  auto size = index == dim ? targetSize
856  : getOrFoldTensorDim(rewriter, loc, localPool,
857  operand, index);
858  outputTensorShape.push_back(size);
859  }
860  Value outputTensor = opBuilder.create<tensor::EmptyOp>(
861  loc, outputTensorShape, rankedTensorType.getElementType());
862 
863  // Emit 'linalg.generic' op
864  auto resultTensor =
865  opBuilder
866  .create<linalg::GenericOp>(
867  loc, outputTensor.getType(), operand, outputTensor, affineMaps,
869  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
870  // Emit 'linalg.yield' op
871  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
872  })
873  .getResult(0);
874 
875  // Cast to original operand type if necessary
876  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
877  loc, operand.getType(), resultTensor);
878 
879  // Emit 'scf.yield' op
880  opBuilder.create<scf::YieldOp>(loc, castResultTensor);
881  };
882 
883  // Emit 'else' region of 'scf.if'
884  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
885  opBuilder.create<scf::YieldOp>(loc, operand);
886  };
887 
888  // Emit 'scf.if' op
889  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
890  emitThenRegion, emitElseRegion);
891  return ifOp.getResult(0);
892 }
893 
895  IndexPool &indexPool, Value operand,
896  ArrayRef<OpFoldResult> targetShape,
897  ArrayRef<Value> masterOperands) {
898  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
899  assert((int64_t)targetShape.size() == rank);
900  assert((int64_t)masterOperands.size() == rank);
901  for (auto index : llvm::seq<int64_t>(0, rank))
902  operand =
903  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
904  targetShape[index], masterOperands[index]);
905  return operand;
906 }
907 
908 static SmallVector<Value>
910  IndexPool &indexPool, ValueRange operands,
911  ArrayRef<OpFoldResult> targetShape,
912  ArrayRef<Value> masterOperands) {
913  // No need to broadcast for unary operations
914  if (operands.size() == 1)
915  return operands;
916 
917  // Broadcast dynamic dimensions operand by operand
918  return llvm::map_to_vector(operands, [&](Value operand) {
919  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
920  targetShape, masterOperands);
921  });
922 }
923 
924 static LogicalResult
926  Operation *operation, ValueRange operands,
927  ArrayRef<OpFoldResult> targetShape,
928  const TypeConverter &converter) {
929  // Generate output tensor
930  auto resultType = cast_or_null<RankedTensorType>(
931  converter.convertType(operation->getResultTypes().front()));
932  if (!resultType) {
933  return rewriter.notifyMatchFailure(operation, "failed to convert type");
934  }
935  Value outputTensor = rewriter.create<tensor::EmptyOp>(
936  loc, targetShape, resultType.getElementType());
937 
938  // Create affine maps. Input affine maps broadcast static dimensions of size
939  // 1. The output affine map is an identity map.
940  //
941  auto rank = resultType.getRank();
942  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
943  auto shape = cast<ShapedType>(operand.getType()).getShape();
944  SmallVector<AffineExpr> affineExprs;
945  for (auto it : llvm::enumerate(shape)) {
946  // Prefer producting identity maps whenever possible (i.e. no broadcasting
947  // needed) because some transforms (like reshape folding)
948  // do not support affine constant exprs.
949  bool requiresBroadcast =
950  (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
951  auto affineExpr = requiresBroadcast
952  ? rewriter.getAffineConstantExpr(0)
953  : rewriter.getAffineDimExpr(it.index());
954  affineExprs.push_back(affineExpr);
955  }
956  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
957  });
958  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
959 
960  // Emit 'linalg.generic' op
961  bool encounteredError = false;
962  auto linalgOp = rewriter.create<linalg::GenericOp>(
963  loc, outputTensor.getType(), operands, outputTensor, affineMaps,
965  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
967  operation, blockArgs.take_front(operation->getNumOperands()),
968  {resultType.getElementType()}, rewriter);
969  if (!opResult) {
970  encounteredError = true;
971  return;
972  }
973  opBuilder.create<linalg::YieldOp>(loc, opResult);
974  });
975  if (encounteredError)
976  return rewriter.notifyMatchFailure(
977  operation, "unable to create linalg.generic body for elementwise op");
978 
979  // Cast 'linalg.generic' result into original result type if needed
980  auto castResult = rewriter.createOrFold<tensor::CastOp>(
981  loc, resultType, linalgOp->getResult(0));
982  rewriter.replaceOp(operation, castResult);
983  return success();
984 }
985 
987  ValueRange operands) {
988  // Shift cannot broadcast
989  if (isa<tosa::MulOp>(operation))
990  return operands.take_front(2);
991  // Input1_zp and output_zp cannot broadcast
992  if (isa<tosa::NegateOp>(operation))
993  return operands.take_front(1);
994  return operands;
995 }
996 
997 static LogicalResult
999  ConversionPatternRewriter &rewriter,
1000  const TypeConverter &converter) {
1001 
1002  // Collect op properties
1003  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1004  assert(operation->getNumOperands() >= 1 &&
1005  "elementwise op expects at least 1 operand");
1006  if (!operandsAndResultsRanked(operation))
1007  return rewriter.notifyMatchFailure(operation,
1008  "Unranked tensors not supported");
1009 
1010  // Lower operation
1011  IndexPool indexPool;
1012  auto loc = operation->getLoc();
1013  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1014  auto [targetShape, masterOperands] =
1015  computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
1016  auto broadcastOperands =
1017  broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
1018  targetShape, masterOperands);
1019  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
1020  targetShape, converter);
1021 }
1022 
1023 // Returns the constant initial value for a given reduction operation. The
1024 // attribute type varies depending on the element type required.
1025 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1026  PatternRewriter &rewriter) {
1027  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1028  return rewriter.getFloatAttr(elementTy, 0.0);
1029 
1030  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1031  return rewriter.getIntegerAttr(elementTy, 0);
1032 
1033  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1034  return rewriter.getFloatAttr(elementTy, 1.0);
1035 
1036  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1037  return rewriter.getIntegerAttr(elementTy, 1);
1038 
1039  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1040  return rewriter.getFloatAttr(
1041  elementTy, APFloat::getLargest(
1042  cast<FloatType>(elementTy).getFloatSemantics(), false));
1043 
1044  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1045  return rewriter.getIntegerAttr(
1046  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1047 
1048  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1049  return rewriter.getFloatAttr(
1050  elementTy, APFloat::getLargest(
1051  cast<FloatType>(elementTy).getFloatSemantics(), true));
1052 
1053  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1054  return rewriter.getIntegerAttr(
1055  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1056 
1057  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1058  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1059 
1060  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1061  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1062 
1063  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1064  return rewriter.getFloatAttr(
1065  elementTy, APFloat::getLargest(
1066  cast<FloatType>(elementTy).getFloatSemantics(), true));
1067 
1068  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1069  return rewriter.getIntegerAttr(
1070  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1071 
1072  return {};
1073 }
1074 
1075 // Creates the body calculation for a reduction. The operations vary depending
1076 // on the input type.
1078  ValueRange args,
1079  Type elementTy,
1080  PatternRewriter &rewriter) {
1081  Location loc = op->getLoc();
1082  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1083  return rewriter.create<arith::AddFOp>(loc, args);
1084  }
1085 
1086  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1087  return rewriter.create<arith::AddIOp>(loc, args);
1088  }
1089 
1090  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1091  return rewriter.create<arith::MulFOp>(loc, args);
1092  }
1093 
1094  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1095  return rewriter.create<arith::MulIOp>(loc, args);
1096  }
1097 
1098  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1099  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1100  }
1101 
1102  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1103  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1104  }
1105 
1106  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1107  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1108  }
1109 
1110  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1111  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1112  }
1113 
1114  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1115  return rewriter.create<arith::AndIOp>(loc, args);
1116 
1117  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1118  return rewriter.create<arith::OrIOp>(loc, args);
1119 
1120  return {};
1121 }
1122 
1123 // Performs the match and rewrite for reduction operations. This includes
1124 // declaring a correctly sized initial value, and the linalg.generic operation
1125 // that reduces across the specified axis.
1126 template <typename OpTy>
1127 static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1128  PatternRewriter &rewriter) {
1129  auto loc = op->getLoc();
1130  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1131  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1132  if (!inputTy || !resultTy)
1133  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1134 
1135  auto elementTy = resultTy.getElementType();
1136  Value input = op->getOperand(0);
1137 
1138  SmallVector<int64_t> reduceShape;
1139  SmallVector<Value> dynDims;
1140  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1141  if (axis != i) {
1142  reduceShape.push_back(inputTy.getDimSize(i));
1143  if (inputTy.isDynamicDim(i))
1144  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1145  }
1146  }
1147 
1148  SmallVector<Value> inputs, outputs;
1149  inputs.push_back(input);
1150 
1151  // First fill the output buffer with the init value.
1152  auto emptyTensor =
1153  rewriter
1154  .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1155  dynDims)
1156  .getResult();
1157 
1158  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1159  if (!fillValueAttr)
1160  return rewriter.notifyMatchFailure(
1161  op, "No initial value found for reduction operation");
1162 
1163  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1164  auto filledTensor = rewriter
1165  .create<linalg::FillOp>(loc, ValueRange{fillValue},
1166  ValueRange{emptyTensor})
1167  .result();
1168  outputs.push_back(filledTensor);
1169 
1170  bool isNanIgnoreMode = false;
1171  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1172  std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1173  // NaN propagation has no meaning for non floating point types.
1174  if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
1175  isNanIgnoreMode = true;
1176  // Because the TOSA spec requires the result be NaN iff all elements in
1177  // the reduction are NaN we can't simply perform a compare and select.
1178  // Additionally we have to keep track of whether we've seen any non-NaN
1179  // values and then do a final select based on this predicate.
1180  auto trueAttr = rewriter.getBoolAttr(true);
1181  auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
1182  auto emptyBoolTensor =
1183  rewriter
1184  .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
1185  dynDims)
1186  .getResult();
1187  auto allResultsNaNTensor =
1188  rewriter
1189  .create<linalg::FillOp>(loc, ValueRange{trueValue},
1190  ValueRange{emptyBoolTensor})
1191  .result();
1192  // Note that because the linalg::ReduceOp has two variadic arguments
1193  // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1194  // need to have the same number of inputs and outputs.
1195  //
1196  // The second input isn't actually used anywhere since the value used to
1197  // update the NaN flag is calculated inside the body of the reduction and
1198  // then used to update an out value.
1199  // In order to satisfy type constraints we just pass another copy of the
1200  // input here.
1201  inputs.push_back(input);
1202  outputs.push_back(allResultsNaNTensor);
1203  }
1204  }
1205 
1206  bool didEncounterError = false;
1207  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
1208  loc, inputs, outputs, axis,
1209  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1210  std::array<Value, 2> binaryArgs{
1211  blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1213  op, binaryArgs, elementTy, rewriter);
1214  if (result)
1215  didEncounterError = true;
1216 
1217  SmallVector<Value> resultsToYield;
1218  if (isNanIgnoreMode) {
1219  auto inputValue = blockArgs[0];
1220  auto initialValue = blockArgs[2];
1221  auto oldAllResultsNanFlagValue = blockArgs[3];
1222 
1223  // Unordered comparison of NaN against itself will always return true.
1224  Value isNaN = nestedBuilder.create<arith::CmpFOp>(
1225  op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
1226  // If we've encountered a NaN, take the non-NaN value.
1227  auto selectOp = nestedBuilder.create<arith::SelectOp>(
1228  op->getLoc(), isNaN, initialValue, result);
1229  // Update the flag which keeps track of whether we have seen a non-NaN
1230  // value.
1231  auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
1232  op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1233  resultsToYield.push_back(selectOp);
1234  resultsToYield.push_back(newAllResultsNanFlagValue);
1235  } else {
1236  resultsToYield.push_back(result);
1237  }
1238  nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
1239  });
1240 
1241  if (!didEncounterError)
1242  return rewriter.notifyMatchFailure(
1243  op, "unable to create linalg.generic body for reduce op");
1244 
1245  if (isNanIgnoreMode) {
1246  // Materialize a check to see whether we encountered any non-NaN values, if
1247  // we didn't we need to select a tensor of NaNs since the result will just
1248  // be the initial identity value propagated through all the compares and
1249  // selects inside the reduction.
1250 
1251  // Create a tensor full of NaNs.
1252  auto nanValueAttr = rewriter.getFloatAttr(
1253  elementTy,
1254  APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1255  auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
1256  auto emptyNanTensor =
1257  rewriter
1258  .create<tensor::EmptyOp>(loc, reduceShape,
1259  resultTy.getElementType(), dynDims)
1260  .getResult();
1261  auto nanFilledTensor =
1262  rewriter
1263  .create<linalg::FillOp>(loc, ValueRange{nanValue},
1264  ValueRange{emptyNanTensor})
1265  .result();
1266 
1267  // Create an empty tensor, non need to fill this since it will be
1268  // overwritten by the select.
1269  auto finalEmptyTensor =
1270  rewriter
1271  .create<tensor::EmptyOp>(loc, reduceShape,
1272  resultTy.getElementType(), dynDims)
1273  .getResult();
1274 
1275  // Do a selection between the tensors akin to:
1276  // result = NaN if "all results NaN" else result.
1277  SmallVector<Value> ins, outs;
1278  ins.push_back(linalgOp->getOpResult(1));
1279  ins.push_back(nanFilledTensor);
1280  ins.push_back(linalgOp->getResult(0));
1281  outs.push_back(finalEmptyTensor);
1282  auto linalgSelect =
1283  rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
1284  linalgOp = linalgSelect;
1285  }
1286 
1287  SmallVector<ReassociationExprs, 4> reassociationMap;
1288  uint64_t expandInputRank =
1289  cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
1290  reassociationMap.resize(expandInputRank);
1291 
1292  for (uint64_t i = 0; i < expandInputRank; i++) {
1293  int32_t dimToPush = i > axis ? i + 1 : i;
1294  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1295  }
1296 
1297  if (expandInputRank != 0) {
1298  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1299  reassociationMap[expandedDim].push_back(
1300  rewriter.getAffineDimExpr(expandedDim + 1));
1301  }
1302 
1303  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1304  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1305  // not have access to such information. This matters when handling dynamically
1306  // sized tensors.
1307  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1308  op, resultTy, linalgOp->getResults()[0], reassociationMap);
1309  return success();
1310 }
1311 
1312 namespace {
1313 
1314 template <typename SrcOp>
1315 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1316 public:
1319 
1320  LogicalResult
1321  matchAndRewrite(SrcOp op, OpAdaptor operands,
1322  ConversionPatternRewriter &rewriter) const final {
1324  op, operands.getOperands(), rewriter, *this->getTypeConverter());
1325  }
1326 };
1327 
1328 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1329 public:
1331 
1332  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1333  PatternRewriter &rewriter) const final {
1334  auto loc = op.getLoc();
1335  auto input = op.getInput();
1336  auto inputTy = cast<ShapedType>(op.getInput().getType());
1337  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1338  unsigned rank = inputTy.getRank();
1339 
1340  // This is an illegal configuration. terminate and log an error
1341  if (op.getRoundingMode() == "INEXACT_ROUND")
1342  return rewriter.notifyMatchFailure(
1343  op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1344  "currently supported");
1345  if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
1346  return rewriter.notifyMatchFailure(
1347  op, "tosa.rescale requires scale32 for double_round to be true");
1348 
1349  if (!isa<IntegerType>(inputTy.getElementType()))
1350  return rewriter.notifyMatchFailure(op, "only support integer type");
1351 
1352  SmallVector<Value> dynDims;
1353  for (int i = 0; i < outputTy.getRank(); i++) {
1354  if (outputTy.isDynamicDim(i)) {
1355  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1356  }
1357  }
1358 
1359  // The shift and multiplier values.
1360  DenseElementsAttr shiftElems;
1361  if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1362  return rewriter.notifyMatchFailure(
1363  op, "tosa.rescale requires constant shift input values");
1364 
1365  DenseElementsAttr multiplierElems;
1366  if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1367  return rewriter.notifyMatchFailure(
1368  op, "tosa.rescale requires constant multiplier input values");
1369 
1370  llvm::SmallVector<int8_t> shiftValues =
1371  llvm::to_vector(shiftElems.getValues<int8_t>());
1372  // explicit cast is required here
1373  llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1374  llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1375  [](IntegerAttr attr) -> int32_t {
1376  return static_cast<int32_t>(attr.getInt());
1377  }));
1378 
1379  // If we shift by more than the bitwidth, this just sets to 0.
1380  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1381  if (shiftValues[i] > 63) {
1382  shiftValues[i] = 0;
1383  multiplierValues[i] = 0;
1384  }
1385  }
1386 
1387  // Double round only occurs if shift is greater than 31, check that this
1388  // is ever true.
1389 
1390  bool doubleRound =
1391  op.getRoundingMode() == "DOUBLE_ROUND" &&
1392  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1393  StringAttr roundingMode = doubleRound
1394  ? rewriter.getStringAttr("DOUBLE_ROUND")
1395  : rewriter.getStringAttr("SINGLE_ROUND");
1396 
1397  SmallVector<AffineMap> indexingMaps = {
1398  rewriter.getMultiDimIdentityMap(rank)};
1399  SmallVector<Value, 4> genericInputs = {input};
1400 
1401  // If we are rescaling per-channel then we need to store the multiplier
1402  // values in a buffer.
1403  Value multiplierConstant;
1404  int64_t multiplierArg = 0;
1405  if (multiplierValues.size() == 1) {
1406  multiplierConstant = rewriter.create<arith::ConstantOp>(
1407  loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1408  } else {
1409  SmallVector<AffineExpr, 2> multiplierExprs{
1410  rewriter.getAffineDimExpr(rank - 1)};
1411  auto multiplierType =
1412  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1413  rewriter.getI32Type());
1414  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1415  loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1416 
1417  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1418  /*symbolCount=*/0, multiplierExprs,
1419  rewriter.getContext()));
1420 
1421  multiplierArg = indexingMaps.size() - 1;
1422  }
1423 
1424  // If we are rescaling per-channel then we need to store the shift
1425  // values in a buffer.
1426  Value shiftConstant;
1427  int64_t shiftArg = 0;
1428  if (shiftValues.size() == 1) {
1429  shiftConstant = rewriter.create<arith::ConstantOp>(
1430  loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1431  } else {
1432  SmallVector<AffineExpr, 2> shiftExprs = {
1433  rewriter.getAffineDimExpr(rank - 1)};
1434  auto shiftType =
1435  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1436  rewriter.getIntegerType(8));
1437  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1438  loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1439  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1440  /*symbolCount=*/0, shiftExprs,
1441  rewriter.getContext()));
1442  shiftArg = indexingMaps.size() - 1;
1443  }
1444 
1445  // Indexing maps for output values.
1446  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1447 
1448  // Construct the indexing maps needed for linalg.generic ops.
1449  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1450  loc, outputTy.getShape(), outputTy.getElementType(),
1451  ArrayRef<Value>({dynDims}));
1452 
1453  auto linalgOp = rewriter.create<linalg::GenericOp>(
1454  loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1455  getNParallelLoopsAttrs(rank),
1456  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1457  ValueRange blockArgs) {
1458  Value value = blockArgs[0];
1459  Type valueTy = value.getType();
1460 
1461  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1462  if (failed(maybeIZp)) {
1463  (void)rewriter.notifyMatchFailure(
1464  op, "input zero point cannot be statically determined");
1465  return;
1466  }
1467 
1468  const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1469  // Extend zeropoint for sub-32bits widths.
1470  const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1471  auto inputZp = nestedBuilder.create<arith::ConstantOp>(
1472  loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
1473  *maybeIZp));
1474 
1475  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1476  if (failed(maybeOZp)) {
1477  (void)rewriter.notifyMatchFailure(
1478  op, "output zero point cannot be statically determined");
1479  return;
1480  };
1481 
1482  IntegerType outIntType =
1483  cast<IntegerType>(blockArgs.back().getType());
1484  unsigned outBitWidth = outIntType.getWidth();
1485  const int32_t outAttrBitwidth = 32;
1486  assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1487  auto outputZp = nestedBuilder.create<arith::ConstantOp>(
1488  loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
1489  *maybeOZp));
1490 
1491  Value multiplier = multiplierConstant ? multiplierConstant
1492  : blockArgs[multiplierArg];
1493  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1494 
1495  if (valueTy.isUnsignedInteger()) {
1496  value = nestedBuilder
1497  .create<UnrealizedConversionCastOp>(
1498  nestedLoc,
1499  nestedBuilder.getIntegerType(
1500  valueTy.getIntOrFloatBitWidth()),
1501  value)
1502  .getResult(0);
1503  }
1504  if (valueTy.getIntOrFloatBitWidth() < 32) {
1505  if (op.getInputUnsigned()) {
1506  value = nestedBuilder.create<arith::ExtUIOp>(
1507  nestedLoc, nestedBuilder.getI32Type(), value);
1508  } else {
1509  value = nestedBuilder.create<arith::ExtSIOp>(
1510  nestedLoc, nestedBuilder.getI32Type(), value);
1511  }
1512  }
1513 
1514  value =
1515  nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1516 
1517  value = nestedBuilder.create<tosa::ApplyScaleOp>(
1518  loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1519  roundingMode);
1520 
1521  // Move to the new zero-point.
1522  value =
1523  nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1524 
1525  // Saturate to the output size.
1526  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1527  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1528 
1529  // Unsigned integers have a difference output value.
1530  if (op.getOutputUnsigned()) {
1531  intMin = 0;
1532  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1533  }
1534 
1535  auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1536  loc, nestedBuilder.getI32IntegerAttr(intMin));
1537  auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1538  loc, nestedBuilder.getI32IntegerAttr(intMax));
1539 
1540  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1541  nestedBuilder, /*isUnsigned=*/false);
1542 
1543  if (outIntType.getWidth() < 32) {
1544  value = nestedBuilder.create<arith::TruncIOp>(
1545  nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1546  value);
1547  }
1548 
1549  if (outIntType.isUnsignedInteger()) {
1550  value = nestedBuilder
1551  .create<UnrealizedConversionCastOp>(nestedLoc,
1552  outIntType, value)
1553  .getResult(0);
1554  }
1555  nestedBuilder.create<linalg::YieldOp>(loc, value);
1556  });
1557 
1558  rewriter.replaceOp(op, linalgOp->getResults());
1559  return success();
1560  }
1561 };
1562 
// Handle the resize case where the input is a 1x1 image. This case
// can entirely avoid having extract operations, which are much
// more difficult to optimize away.
1566 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1567 public:
1569 
1570  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1571  PatternRewriter &rewriter) const final {
1572  Location loc = op.getLoc();
1573  ImplicitLocOpBuilder builder(loc, rewriter);
1574  auto input = op.getInput();
1575  auto inputTy = cast<RankedTensorType>(input.getType());
1576  auto resultTy = cast<RankedTensorType>(op.getType());
1577  const bool isBilinear = op.getMode() == "BILINEAR";
1578 
1579  auto inputH = inputTy.getDimSize(1);
1580  auto inputW = inputTy.getDimSize(2);
1581  auto outputH = resultTy.getDimSize(1);
1582  auto outputW = resultTy.getDimSize(2);
1583 
1584  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1585  return rewriter.notifyMatchFailure(
1586  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1587 
1588  // TODO(suderman): These string values should be declared the TOSA dialect.
1589  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1590  return rewriter.notifyMatchFailure(
1591  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1592 
1593  if (inputTy == resultTy) {
1594  rewriter.replaceOp(op, input);
1595  return success();
1596  }
1597 
1598  SmallVector<int64_t> scale;
1599  if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
1600  return failure();
1601  }
1602 
1603  // Collapse the unit width and height away.
1604  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1605  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1606  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1607  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1608  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1609 
1610  auto collapseTy =
1611  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1612  inputTy.getElementType());
1613  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1614  reassociationMap);
1615 
1616  // Get any dynamic shapes that appear in the input format.
1617  llvm::SmallVector<Value> outputDynSize;
1618  if (inputTy.isDynamicDim(0))
1619  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1620  if (inputTy.isDynamicDim(3))
1621  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1622 
1623  // Generate the elementwise operation for casting scaling the input value.
1624  auto genericTy = collapseTy.clone(resultTy.getElementType());
1625  Value empty = builder.create<tensor::EmptyOp>(
1626  genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1627  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1628  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1629  utils::IteratorType::parallel);
1630 
1631  auto generic = builder.create<linalg::GenericOp>(
1632  genericTy, ValueRange{collapse}, ValueRange{empty},
1633  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1634  [=](OpBuilder &b, Location loc, ValueRange args) {
1635  Value value = args[0];
1636  // This is the quantized case.
1637  if (inputTy.getElementType() != resultTy.getElementType()) {
1638  value =
1639  b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1640 
1641  if (isBilinear && scale[0] != 0) {
1642  Value scaleY = b.create<arith::ConstantOp>(
1643  loc, b.getI32IntegerAttr(scale[0]));
1644  value = b.create<arith::MulIOp>(loc, value, scaleY);
1645  }
1646 
1647  if (isBilinear && scale[2] != 0) {
1648  Value scaleX = b.create<arith::ConstantOp>(
1649  loc, b.getI32IntegerAttr(scale[2]));
1650  value = b.create<arith::MulIOp>(loc, value, scaleX);
1651  }
1652  }
1653 
1654  b.create<linalg::YieldOp>(loc, value);
1655  });
1656 
1657  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1658  op, resultTy, generic.getResults()[0], reassociationMap);
1659  return success();
1660  }
1661 };
1662 
// TOSA resize with width or height of 1 may be broadcasted to a wider
// dimension. This is done by materializing a new tosa.resize without
// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    // NHWC layout: dim 0 = batch, 1 = height, 2 = width, 3 = channels.
    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Only fire when at least one spatial dim is a 1 -> N broadcast.
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    // Re-issue the resize without the broadcasting dims; this new op is
    // picked up by the other resize lowerings.
    auto resizeTy = resultTy.clone(resizeShape);
    auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
                                                 op.getOffset(), op.getBorder(),
                                                 op.getMode());

    // Collapse any unit result dims into their neighbor so the broadcast
    // below can re-introduce them via the affine map.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    // Collapsed shape keeps only the non-broadcast dims.
    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // The input indexing map omits the broadcast dims (they do not exist in
    // the collapsed operand); the output map is the identity, so each
    // collapsed element is replicated across the broadcast dims.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          // Pure copy body: broadcasting is fully expressed by the maps.
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
1762 
// Fully general lowering of tosa.resize: a rank-4 linalg.generic with no
// tensor inputs whose body computes the source coordinates per output element
// and gathers (and, for BILINEAR, interpolates) via tensor.extract.
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    // f16/f32 results interpolate in floating point; other (integer) element
    // types interpolate in fixed point with the scale numerator as the unit.
    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    // NHWC layout: dims 1 and 2 are the spatial height/width.
    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // Output-only generic op over the identity map; the region is built
    // explicitly below.
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      // Current output coordinate.
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      // Scales are rationals: {yN, yD, xN, xD} (numerator/denominator pairs).
      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the ix and dx values for both the X and Y dimensions.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          // Degenerate axis: always sample index 0 with zero fraction.
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Compute the ix and dx values for the X and Y dimensions - int case.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n;
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        // Round to the nearest source pixel (fraction >= 0.5 rounds up) and
        // clamp into [0, size - 1].
        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            // Fixed point: dval/scale >= 0.5  <=>  2*dval >= scale.
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        // Produce the clamped index pair (i, i+1) as index-typed values.
        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        // Linalg equivalent to the section below:
        //   int16_t iy0 = apply_max(iy, 0);
        //   int16_t iy1 = apply_min(iy + 1, IH - 1);
        //   int16_t ix0 = apply_max(ix, 0);
        //   int16_t ix1 = apply_min(ix + 1, IW - 1);
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        // The four neighboring pixels used for the interpolation.
        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          // lerp(val0, val1, delta); a unit axis just forwards val0.
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // Linalg equivalent to the section below:
          //   topAcc = v00 * (unit_x - dx);
          //   topAcc += v01 * dx;
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // Linalg equivalent to the section below:
          //   bottomAcc = v10 * (unit_x - dx);
          //   bottomAcc += v11 * dx;
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // Linalg equivalent to the section below:
          //   result = topAcc * (unit_y - dy) + bottomAcc * dy
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform in quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          // Widen deltas to the result width so the fixed-point multiplies
          // below are done at full precision.
          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          // Likewise widen the scale numerators when the result is wider.
          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          // Fixed-point lerp with `scale` as the unit: the result keeps a
          // factor of scale per interpolated axis.
          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
2035 
2036 // At the codegen level any identity operations should be removed. Any cases
2037 // where identity is load-bearing (e.g. cross device computation) should be
2038 // handled before lowering to codegen.
2039 template <typename SrcOp>
2040 class IdentityNConverter : public OpRewritePattern<SrcOp> {
2041 public:
2043 
2044  LogicalResult matchAndRewrite(SrcOp op,
2045  PatternRewriter &rewriter) const final {
2046  rewriter.replaceOp(op, op.getOperation()->getOperands());
2047  return success();
2048  }
2049 };
2050 
2051 template <typename SrcOp>
2052 class ReduceConverter : public OpRewritePattern<SrcOp> {
2053 public:
2055 
2056  LogicalResult matchAndRewrite(SrcOp reduceOp,
2057  PatternRewriter &rewriter) const final {
2058  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2059  }
2060 };
2061 
// Lowers tosa.reverse to a linalg.generic whose body gathers each output
// element from the mirrored coordinate along the reversed axis.
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    // Collect dynamic dims for the destination tensor.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // Runtime extent of the reversed axis, needed to mirror the index:
    // rev(i) = size - 1 - i.
    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          // Build the gather coordinate: identity on every dim except
          // `axis`, which is flipped.
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
2119 
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcast to the appropriate
// multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Multiples must be compile-time constants; -1 marks a dynamic multiple.
    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    // genericShape interleaves (multiple_i, input_dim_i) pairs, i.e. each
    // tiled dim is preceded by its tile-count dim.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    // Dynamic sizes for the empty destination tensor.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions,
    // i.e. the odd positions of the interleaved generic shape.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    // Broadcasting copy: each output element forwards the matching input
    // element; replication is expressed entirely by the indexing maps.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    // Fold the interleaved shape back into the tiled result shape.
    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
2188 
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Auxiliary result carrying the running max in the input element type.
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Dynamic dims of the result: every dynamic input dim except the reduced
    // axis survives into the output shape.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    // Indexing maps: identity for the input, axis dropped for both outputs.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Position along the reduced axis, cast to the index result type.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == "IGNORE") {
              // Only update index & max value for non NaN values. If all
              // values are NaNs, the initial index will be returned, which
              // is 0.
              predicate = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
            } else {
              // Update max value if either of the following is true:
              // - new value is bigger
              // - cur max is not NaN and new value is NaN
              Value gt = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
              predicate = rewriter.create<arith::AndIOp>(
                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            // Unsupported element type: record the failure; checked after
            // the builder callback has run.
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index result is exposed; the max buffer is dropped.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
2334 
// Lowers tosa.gather to a linalg.generic over the indices tensor whose body
// extracts the selected (batch, index, channel) element from the values.
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    // The indices operand is iterated over (batch, index) = (d0, d1); the
    // output additionally iterates the channel dim d2 via the identity map.
    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          // Gather values[batch, indices[batch, index], channel].
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  // Returns the dynamic-dimension Values of the gather result, in result dim
  // order: values dim 0 (batch), indices dim 1, values dim 2 (channels).
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    // Record the dim size only when it is dynamic (static sizes come back as
    // attributes and are dropped by the dyn_cast).
    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
2404 
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
// Lowers tosa.table to a linalg.generic whose body performs the lookup.
// The i8 case is a direct gather from the table; the i16 case gathers two
// adjacent table entries and linearly interpolates between them.
2408 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2409 public:
 2411 
 2412  LogicalResult matchAndRewrite(tosa::TableOp op,
 2413  PatternRewriter &rewriter) const final {
 2414  auto loc = op.getLoc();
 2415  Value input = op.getInput1();
 2416  Value table = op.getTable();
 2417  auto inputTy = cast<ShapedType>(input.getType());
 2418  auto tableTy = cast<ShapedType>(table.getType());
 2419  auto resultTy = cast<ShapedType>(op.getType());
 2420 
 2421  auto inputElementTy = inputTy.getElementType();
 2422  auto tableElementTy = tableTy.getElementType();
 2423  auto resultElementTy = resultTy.getElementType();
 2424 
// Gather SSA values for any dynamic dimensions of the result so the empty
// output tensor can be created with the correct runtime shape.
 2425  SmallVector<Value> dynDims;
 2426  for (int i = 0; i < resultTy.getRank(); ++i) {
 2427  if (inputTy.isDynamicDim(i)) {
 2428  dynDims.push_back(
 2429  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
 2430  }
 2431  }
 2432 
 2433  auto emptyTensor = rewriter
 2434  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
 2435  resultElementTy, dynDims)
 2436  .getResult();
 2437 
// Identity maps for input and output: the lookup is purely elementwise.
 2438  SmallVector<AffineMap, 2> affineMaps = {
 2439  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
 2440  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
 2441 
 2442  auto genericOp = rewriter.create<linalg::GenericOp>(
 2443  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
 2444  getNParallelLoopsAttrs(resultTy.getRank()));
// NOTE: the op is replaced before the generic's body region is populated
// below; the early `return success()` paths inside the scope rely on this.
 2445  rewriter.replaceOp(op, genericOp.getResult(0));
 2446 
 2447  {
 2448  OpBuilder::InsertionGuard regionGuard(rewriter);
 2449  Block *block = rewriter.createBlock(
 2450  &genericOp.getRegion(), genericOp.getRegion().end(),
 2451  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
 2452 
 2453  auto inputValue = block->getArgument(0);
 2454  rewriter.setInsertionPointToStart(block);
// i8 -> i8 lookup: bias the signed input by 128 so the value range
// [-128, 127] maps onto table indices [0, 255], then gather directly.
 2455  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
 2456  resultElementTy.isInteger(8)) {
 2457  Value index = rewriter.create<arith::IndexCastOp>(
 2458  loc, rewriter.getIndexType(), inputValue);
 2459  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
 2460  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
 2461  index, offset);
 2462  Value extract =
 2463  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
 2464  rewriter.create<linalg::YieldOp>(loc, extract);
 2465  return success();
 2466  }
 2467 
// i16 -> i32 lookup: bias by 32768, split the biased value into a table
// index (high bits) and a 7-bit fraction, then linearly interpolate
// between the two adjacent table entries.
 2468  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
 2469  resultElementTy.isInteger(32)) {
 2470  Value extend = rewriter.create<arith::ExtSIOp>(
 2471  loc, rewriter.getI32Type(), inputValue);
 2472 
 2473  auto offset = rewriter.create<arith::ConstantOp>(
 2474  loc, rewriter.getI32IntegerAttr(32768));
 2475  auto seven = rewriter.create<arith::ConstantOp>(
 2476  loc, rewriter.getI32IntegerAttr(7));
 2477  auto one = rewriter.create<arith::ConstantOp>(
 2478  loc, rewriter.getI32IntegerAttr(1));
 2479  auto b1111111 = rewriter.create<arith::ConstantOp>(
 2480  loc, rewriter.getI32IntegerAttr(127));
 2481 
 2482  // Compute the index and fractional part from the input value:
 2483  // value = value + 32768
 2484  // index = value >> 7;
 2485  // fraction = 0x01111111 & value
 2486  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
 2487  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
 2488  Value fraction =
 2489  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
 2490 
 2491  // Extract the base and next values from the table.
 2492  // base = (int32_t) table[index];
 2493  // next = (int32_t) table[index + 1];
 2494  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
 2495 
 2496  index = rewriter.create<arith::IndexCastOp>(
 2497  loc, rewriter.getIndexType(), index);
 2498  indexPlusOne = rewriter.create<arith::IndexCastOp>(
 2499  loc, rewriter.getIndexType(), indexPlusOne);
 2500 
 2501  Value base =
 2502  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
 2503  Value next = rewriter.create<tensor::ExtractOp>(
 2504  loc, table, ValueRange{indexPlusOne});
 2505 
 2506  base =
 2507  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
 2508  next =
 2509  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
 2510 
 2511  // Use the fractional part to interpolate between the input values:
 2512  // result = (base << 7) + (next - base) * fraction
 2513  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
 2514  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
 2515  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
 2516  Value result =
 2517  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
 2518 
 2519  rewriter.create<linalg::YieldOp>(loc, result);
 2520 
 2521  return success();
 2522  }
 2523  }
 2524 
// No supported (input, table, result) element-type combination matched.
 2525  return rewriter.notifyMatchFailure(
 2526  op, "unable to create body for tosa.table op");
 2527  }
 2528 };
2529 
// Lowers tosa.rfft2d to a linalg.generic computing the real DFT directly:
// three parallel dims over (batch, output H, output W) and two reduction
// dims over the input H/W, accumulating sum-of-products with cos/sin.
2530 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
 2532 
 2533  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
 2534 
// Returns (ofr / 2) + 1 as an OpFoldResult; used for the half-sized W
// dimension of the RFFT output.
 2535  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
 2536  OpFoldResult ofr) {
 2537  auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
 2538  auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
 2539 
 2540  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
 2541  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
 2542  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
 2543  return getAsOpFoldResult(plusOne);
 2544  }
 2545 
// Computes the [N, H, W/2 + 1] output type, appending SSA values for any
// dynamic sizes to `dynamicSizes`.
 2546  static RankedTensorType
 2547  computeOutputShape(OpBuilder &builder, Location loc, Value input,
 2548  llvm::SmallVectorImpl<Value> &dynamicSizes) {
 2549  // Get [N, H, W]
 2550  auto dims = tensor::getMixedSizes(builder, loc, input);
 2551 
 2552  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
 2553  // output tensors.
 2554  dims[2] = halfPlusOne(builder, loc, dims[2]);
 2555 
 2556  llvm::SmallVector<int64_t, 3> staticSizes;
 2557  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
 2558 
 2559  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
 2560  return RankedTensorType::get(staticSizes, elementType);
 2561  }
 2562 
// Creates a tensor of the given type filled with zeros; used as the
// accumulator init for the reduction.
 2563  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
 2564  RankedTensorType type,
 2565  llvm::ArrayRef<Value> dynamicSizes) {
 2566  auto emptyTensor =
 2567  rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
 2568  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
 2569  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
 2570  auto filledTensor = rewriter
 2571  .create<linalg::FillOp>(loc, ValueRange{fillValue},
 2572  ValueRange{emptyTensor})
 2573  .result();
 2574  return filledTensor;
 2575  }
 2576 
// Converts an index value to the given float type, going through an i32
// (or i64 for wide floats) unsigned intermediate.
 2577  static Value castIndexToFloat(OpBuilder &builder, Location loc,
 2578  FloatType type, Value value) {
 2579  auto integerVal = builder.create<arith::IndexCastUIOp>(
 2580  loc,
 2581  type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
 2582  : builder.getI32Type(),
 2583  value);
 2584 
 2585  return builder.create<arith::UIToFPOp>(loc, type, integerVal);
 2586  }
 2587 
// Materializes linalg.index for loop dimension `index` as a float value.
 2588  static Value createLinalgIndex(OpBuilder &builder, Location loc,
 2589  FloatType type, int64_t index) {
 2590  auto indexVal = builder.create<linalg::IndexOp>(loc, index);
 2591  return castIndexToFloat(builder, loc, type, indexVal);
 2592  }
 2593 
// Builds a list of affine dim expressions from the given dim positions.
 2594  template <typename... Args>
 2595  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
 2596  Args... args) {
 2597  return {builder.getAffineDimExpr(args)...};
 2598  }
 2599 
 2600  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
 2601  PatternRewriter &rewriter) const override {
 2602  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
 2603  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
 2604  return rewriter.notifyMatchFailure(rfft2d,
 2605  "only supports ranked tensors");
 2606  }
 2607 
 2608  auto loc = rfft2d.getLoc();
 2609  auto input = rfft2d.getInputReal();
 2610  auto elementType =
 2611  dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
 2612  if (!elementType)
 2613  return rewriter.notifyMatchFailure(rfft2d,
 2614  "only supports float element types");
 2615 
 2616  // Compute the output type and set of dynamic sizes
 2617  llvm::SmallVector<Value> dynamicSizes;
 2618  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
 2619 
 2620  // Iterator types for the linalg.generic implementation
 2622  utils::IteratorType::parallel, utils::IteratorType::parallel,
 2623  utils::IteratorType::parallel, utils::IteratorType::reduction,
 2624  utils::IteratorType::reduction};
 2625 
 2626  // Inputs/outputs to the linalg.generic implementation
 2627  llvm::SmallVector<Value> genericOpInputs = {input};
 2628  llvm::SmallVector<Value> genericOpOutputs = {
 2629  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
 2630  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
 2631 
 2632  // Indexing maps for input and output tensors
// Input is indexed by (batch, iy, ix) = (d0, d3, d4); both outputs by
// (batch, oy, ox) = (d0, d1, d2).
 2633  auto indexingMaps = AffineMap::inferFromExprList(
 2634  llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
 2635  affineDimsExpr(rewriter, 0, 1, 2),
 2636  affineDimsExpr(rewriter, 0, 1, 2)},
 2637  rewriter.getContext());
 2638 
 2639  // Width and height dimensions of the original input.
 2640  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
 2641  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
 2642 
 2643  // Constants and dimension sizes
 2644  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
 2645  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
 2646  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
 2647  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
 2648 
 2649  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
 2650  Value valReal = args[0];
 2651  Value sumReal = args[1];
 2652  Value sumImag = args[2];
 2653 
 2654  // Indices for angle computation
 2655  Value oy = builder.create<linalg::IndexOp>(loc, 1);
 2656  Value ox = builder.create<linalg::IndexOp>(loc, 2);
 2657  Value iy = builder.create<linalg::IndexOp>(loc, 3);
 2658  Value ix = builder.create<linalg::IndexOp>(loc, 4);
 2659 
 2660  // Calculating angle without integer parts of components as sin/cos are
 2661  // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
 2662  // / W);
 2663  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
 2664  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
 2665 
 2666  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
 2667  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
 2668 
 2669  auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
 2670  auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
 2671 
 2672  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
 2673  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
 2674  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
 2675  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
 2676 
 2677  // realComponent = valReal * cos(angle)
 2678  // imagComponent = valReal * sin(angle)
 2679  auto cosAngle = builder.create<math::CosOp>(loc, angle);
 2680  auto sinAngle = builder.create<math::SinOp>(loc, angle);
 2681  auto realComponent =
 2682  builder.create<arith::MulFOp>(loc, valReal, cosAngle);
 2683  auto imagComponent =
 2684  builder.create<arith::MulFOp>(loc, valReal, sinAngle);
 2685 
 2686  // outReal = sumReal + realComponent
 2687  // outImag = sumImag - imagComponent
 2688  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
 2689  auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
 2690 
 2691  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
 2692  };
 2693 
 2694  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
 2695  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
 2696  indexingMaps, iteratorTypes, buildBody);
 2697 
 2698  return success();
 2699  }
 2700 };
2701 
// Lowers tosa.fft2d to a linalg.generic computing the complex DFT
// directly, reusing the helpers defined on RFFT2dConverter. The `inverse`
// attribute flips the sign of the angle.
2702 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
 2704 
 2705  LogicalResult matchAndRewrite(FFT2dOp fft2d,
 2706  PatternRewriter &rewriter) const override {
 2707  if (!llvm::all_of(fft2d->getOperandTypes(),
 2708  RFFT2dConverter::isRankedTensor) ||
 2709  !llvm::all_of(fft2d->getResultTypes(),
 2710  RFFT2dConverter::isRankedTensor)) {
 2711  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
 2712  }
 2713 
 2714  Location loc = fft2d.getLoc();
 2715  Value input_real = fft2d.getInputReal();
 2716  Value input_imag = fft2d.getInputImag();
 2717  BoolAttr inverse = fft2d.getInverseAttr();
 2718 
 2719  auto real_el_ty = cast<FloatType>(
 2720  cast<ShapedType>(input_real.getType()).getElementType());
 2721  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
 2722  cast<ShapedType>(input_imag.getType()).getElementType());
 2723 
// The real and imaginary inputs must share an element type.
 2724  assert(real_el_ty == imag_el_ty);
 2725 
 2726  // Compute the output type and set of dynamic sizes
 2727  SmallVector<Value> dynamicSizes;
 2728 
 2729  // Get [N, H, W]
 2730  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
 2731 
 2732  SmallVector<int64_t, 3> staticSizes;
 2733  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
 2734 
 2735  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
 2736 
 2737  // Iterator types for the linalg.generic implementation
 2738  SmallVector<utils::IteratorType, 5> iteratorTypes = {
 2739  utils::IteratorType::parallel, utils::IteratorType::parallel,
 2740  utils::IteratorType::parallel, utils::IteratorType::reduction,
 2741  utils::IteratorType::reduction};
 2742 
 2743  // Inputs/outputs to the linalg.generic implementation
 2744  SmallVector<Value> genericOpInputs = {input_real, input_imag};
 2745  SmallVector<Value> genericOpOutputs = {
 2746  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
 2747  dynamicSizes),
 2748  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
 2749  dynamicSizes)};
 2750 
 2751  // Indexing maps for input and output tensors
// Both inputs are indexed by (batch, iy, ix) = (d0, d3, d4); both outputs
// by (batch, oy, ox) = (d0, d1, d2).
 2752  auto indexingMaps = AffineMap::inferFromExprList(
 2753  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
 2754  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
 2755  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
 2756  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
 2757  rewriter.getContext());
 2758 
 2759  // Width and height dimensions of the original input.
 2760  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
 2761  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
 2762 
 2763  // Constants and dimension sizes
 2764  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
 2765  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
 2766  Value constH =
 2767  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
 2768  Value constW =
 2769  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
 2770 
 2771  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
 2772  Value valReal = args[0];
 2773  Value valImag = args[1];
 2774  Value sumReal = args[2];
 2775  Value sumImag = args[3];
 2776 
 2777  // Indices for angle computation
 2778  Value oy = builder.create<linalg::IndexOp>(loc, 1);
 2779  Value ox = builder.create<linalg::IndexOp>(loc, 2);
 2780  Value iy = builder.create<linalg::IndexOp>(loc, 3);
 2781  Value ix = builder.create<linalg::IndexOp>(loc, 4);
 2782 
 2783  // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
 2784  // ox) % W ) / W);
 2785  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
 2786  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
 2787 
 2788  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
 2789  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
 2790 
 2791  auto iyRemFloat =
 2792  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
 2793  auto ixRemFloat =
 2794  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
 2795 
 2796  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
 2797  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
 2798 
 2799  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
 2800  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
 2801 
// For the inverse transform the angle is negated.
 2802  if (inverse.getValue()) {
 2803  angle = builder.create<arith::MulFOp>(
 2804  loc, angle,
 2805  rewriter.create<arith::ConstantOp>(
 2806  loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
 2807  }
 2808 
 2809  // realComponent = val_real * cos(a) + val_imag * sin(a);
 2810  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
 2811  auto cosAngle = builder.create<math::CosOp>(loc, angle);
 2812  auto sinAngle = builder.create<math::SinOp>(loc, angle);
 2813 
 2814  auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
 2815  auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
 2816  auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
 2817 
 2818  auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
 2819  auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
 2820 
 2821  auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
 2822 
 2823  // outReal = sumReal + realComponent
// The sign is already folded into imagComponent (icos - isin), so the
// accumulation is an addition here, unlike the RFFT2d lowering.
 2824  // outImag = sumImag + imagComponent
 2825  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
 2826  auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
 2827 
 2828  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
 2829  };
 2830 
 2831  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
 2832  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
 2833  indexingMaps, iteratorTypes, buildBody);
 2834 
 2835  return success();
 2836  }
 2837 };
2838 
2839 } // namespace
2840 
2842  const TypeConverter &converter, RewritePatternSet *patterns) {
2843 
2844  // We have multiple resize coverters to handle degenerate cases.
2845  patterns->add<GenericResizeConverter>(patterns->getContext(),
2846  /*benefit=*/100);
2847  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2848  /*benefit=*/200);
2849  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2850  /*benefit=*/300);
2851 
2852  patterns->add<
2853  // clang-format off
2854  PointwiseConverter<tosa::AddOp>,
2855  PointwiseConverter<tosa::SubOp>,
2856  PointwiseConverter<tosa::MulOp>,
2857  PointwiseConverter<tosa::IntDivOp>,
2858  PointwiseConverter<tosa::NegateOp>,
2859  PointwiseConverter<tosa::PowOp>,
2860  PointwiseConverter<tosa::ReciprocalOp>,
2861  PointwiseConverter<tosa::RsqrtOp>,
2862  PointwiseConverter<tosa::LogOp>,
2863  PointwiseConverter<tosa::ExpOp>,
2864  PointwiseConverter<tosa::AbsOp>,
2865  PointwiseConverter<tosa::SinOp>,
2866  PointwiseConverter<tosa::CosOp>,
2867  PointwiseConverter<tosa::TanhOp>,
2868  PointwiseConverter<tosa::ErfOp>,
2869  PointwiseConverter<tosa::BitwiseAndOp>,
2870  PointwiseConverter<tosa::BitwiseOrOp>,
2871  PointwiseConverter<tosa::BitwiseNotOp>,
2872  PointwiseConverter<tosa::BitwiseXorOp>,
2873  PointwiseConverter<tosa::LogicalAndOp>,
2874  PointwiseConverter<tosa::LogicalNotOp>,
2875  PointwiseConverter<tosa::LogicalOrOp>,
2876  PointwiseConverter<tosa::LogicalXorOp>,
2877  PointwiseConverter<tosa::CastOp>,
2878  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2879  PointwiseConverter<tosa::LogicalRightShiftOp>,
2880  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2881  PointwiseConverter<tosa::ClzOp>,
2882  PointwiseConverter<tosa::SelectOp>,
2883  PointwiseConverter<tosa::GreaterOp>,
2884  PointwiseConverter<tosa::GreaterEqualOp>,
2885  PointwiseConverter<tosa::EqualOp>,
2886  PointwiseConverter<tosa::MaximumOp>,
2887  PointwiseConverter<tosa::MinimumOp>,
2888  PointwiseConverter<tosa::CeilOp>,
2889  PointwiseConverter<tosa::FloorOp>,
2890  PointwiseConverter<tosa::ClampOp>,
2891  PointwiseConverter<tosa::SigmoidOp>
2892  >(converter, patterns->getContext());
2893 
2894  patterns->add<
2895  IdentityNConverter<tosa::IdentityOp>,
2896  ReduceConverter<tosa::ReduceAllOp>,
2897  ReduceConverter<tosa::ReduceAnyOp>,
2898  ReduceConverter<tosa::ReduceMinOp>,
2899  ReduceConverter<tosa::ReduceMaxOp>,
2900  ReduceConverter<tosa::ReduceSumOp>,
2901  ReduceConverter<tosa::ReduceProductOp>,
2902  ArgMaxConverter,
2903  GatherConverter,
2904  RescaleConverter,
2905  ReverseConverter,
2906  RFFT2dConverter,
2907  FFT2dConverter,
2908  TableConverter,
2909  TileConverter>(patterns->getContext());
2910  // clang-format on
2911 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands)
static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result)
static bool operandsAndResultsRanked(Operation *operation)
const float * table
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:106
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:198
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:385
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:252
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:370
IntegerType getI64Type()
Definition: Builders.cpp:67
IntegerType getI32Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:322
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:362
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:55
IndexType getIndexType()
Definition: Builders.cpp:53
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:428
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:64
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:73
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319