MLIR  21.0.0git
TosaToLinalg.cpp
Go to the documentation of this file.
1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 
34 #include <numeric>
35 #include <type_traits>
36 
37 using namespace mlir;
38 using namespace mlir::tosa;
39 
40 // Helper function to materialize the semantically correct compare and select
41 // operations given a binary operation with a specific NaN propagation mode.
42 //
43 // In the case of "PROPAGATE" semantics no compare and selection is required and
44 // this function does nothing.
45 //
46 // In the case of "IGNORE" semantics this function materializes a comparison of
47 // the current operands to the op which will return true for any NaN
48 // argument and then selects between the non-NaN operation argument and the
49 // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
50 // code:
51 //
52 // In the case that the op is operating on non floating point types we ignore
53 // the attribute completely, this is consistent with the TOSA spec which has
54 // the following wording: "This attribute is ignored by non floating-point
55 // types."
56 //
57 // binary<op>(lhs, rhs):
58 // result = op(lhs, rhs)
59 // if lhs == NaN return rhs
60 // if rhs == NaN return lhs
61 // return result
62 template <typename OpTy>
63 static Value
65  Value lhs, Value rhs, Value result) {
66  // NaN propagation has no meaning for non floating point types.
67  if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
68  return result;
69 
70  auto nanMode = op.getNanMode();
71  if (nanMode == "PROPAGATE")
72  return result;
73 
74  // Unordered comparison of NaN against itself will always return true.
75  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
76  op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
77  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
78  op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
79  Value rhsOrResult =
80  rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
81  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
82  rhsOrResult);
83 }
84 
// Build the scalar computation that implements a single element-wise TOSA op
// inside the body of a linalg.generic. 'args' are the scalar block arguments
// corresponding to the op's operands, 'resultTypes' the scalar result types.
// Returns the computed scalar Value, or nullptr (after notifyMatchFailure)
// when the op / element-type combination is not supported.
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    // Integer abs as max(x, 0 - x); note INT_MIN stays INT_MIN (wrapping sub).
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    // The shift operand must be a compile-time constant to pick the lowering.
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
      return nullptr;
    }

    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0) {
        // A non-zero shift selects the scaled multiply path: widen the
        // operands to i32 and apply tosa.apply_scale, then truncate back.
        auto shiftConst =
            rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
        if (!a.getType().isInteger(32))
          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

        if (!b.getType().isInteger(32))
          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

        auto result = rewriter.create<tosa::ApplyScaleOp>(
            loc, rewriter.getI32Type(), a, b, shiftConst,
            rewriter.getStringAttr("SINGLE_ROUND"));

        if (elementTy.isInteger(32))
          return result;

        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
      }

      // Plain multiply: sign-extend narrower operands up to the result width.
      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
    }
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    if (failed(maybeInZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "input1 zero point cannot be statically determined");
      return nullptr;
    }

    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "output zero point cannot be statically determined");
      return nullptr;
    }

    int64_t inZp = *maybeInZp;
    int64_t outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      if (!inZp && !outZp) {
        // No zero points: plain two's-complement negation.
        auto constant = rewriter.create<arith::ConstantOp>(
            loc, IntegerAttr::get(elementTy, 0));
        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                              args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      const int64_t zpAdd = inZp + outZp;
      const int64_t maxValue =
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
          std::abs(zpAdd) + 1;

      // Convert that maximum value into the maximum bitwidth needed to
      // represent it. We assume 48-bit numbers may be supported further in
      // the pipeline.
      int intermediateBitWidth = 64;
      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
        intermediateBitWidth = 16;
      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
        intermediateBitWidth = 32;
      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
        intermediateBitWidth = 48;
      }

      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
      Value zpAddValue = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

      // The negation can be applied by doing:
      //  outputValue = inZp + outZp - inputValue
      auto ext =
          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

      // Clamp to the negation range.
      Value min = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
          intermediateType);
      Value max = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
          intermediateType);
      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

      // Truncate to the final value.
      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
    }
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    // Bitwise-not lowered as xor with an all-ones constant.
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    // Rounding: add 1 to the shifted result when the shift amount is positive
    // and the last bit shifted out of input1 was set.
    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that the shift amount (input2) is greater than zero.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Extract bit (shift - 1) of input1, i.e. the last bit shifted out.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    // Boolean not lowered as xor with true.
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    // The select element type comes from operand 1 (operand 0 is the i1
    // predicate).
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    // Convert the clamp bounds into the element type's float semantics.
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating point types.
    // NOTE(review): always false here — this branch is guarded by
    // isa<FloatType>(elementTy) above, so this check is redundant.
    if (!isa<FloatType>(elementTy))
      return result;

    // In the case of "PROPAGATE" semantics no compare and selection is
    // required.
    if (nanMode == "PROPAGATE")
      return result;

    // In the case of "IGNORE" semantics materialize a comparison
    // of the current operand to the reduction which will return true for a NaN
    // argument and then selects between the initial reduction value and the
    // calculated result based on whether the argument is NaN or not. In pseudo
    // code:
    //
    // reduce<op>(x, init):
    //   result = op(init, x)
    //   return init if x == NaN else result

    // Unordered comparison of NaN against itself will always return true.
    Value isNaN = rewriter.create<arith::CmpFOp>(
        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
    // is NaN.
    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Ensure that the bounds are representable as n-bit signed/unsigned
    // integers.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    // sigmoid(x) = 1 / (1 + exp(-x))
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean, floats need to only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      // TOSA fp->int casting rounds to nearest even before converting.
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to too short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integral values can be represented in the integer space.
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two and thus
        // consists of a single leading bit. Therefore we can clamp the input
        // in the floating-point domain.

        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // Due to earlier check we know exponant range is big enough to represent
      // int min. We can therefore rely on int max + 1 being representable as
      // well because it's just int min with a positive sign. So clamp the min
      // value and compare against that to select the max int value if needed.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
704 
706 
707 // Emit an 'arith.constant' op for the given index if it has not been created
708 // yet, or return an existing constant. This will prevent an excessive creation
709 // of redundant constants, easing readability of emitted code for unit tests.
711  IndexPool &indexPool, int64_t index) {
712  auto [it, inserted] = indexPool.try_emplace(index);
713  if (inserted)
714  it->second =
715  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
716  return it->second;
717 }
718 
720  IndexPool &indexPool, Value tensor, int64_t index) {
721  auto indexValue = createIndex(rewriter, loc, indexPool, index);
722  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
723 }
724 
726  IndexPool &indexPool, Value tensor,
727  int64_t index) {
728  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
729  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
730  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
731  if (shapedType.isDynamicDim(index))
732  return getTensorDim(rewriter, loc, indexPool, tensor, index);
733  return rewriter.getIndexAttr(shapedType.getDimSize(index));
734 }
735 
736 static bool operandsAndResultsRanked(Operation *operation) {
737  auto isRanked = [](Value value) {
738  return isa<RankedTensorType>(value.getType());
739  };
740  return llvm::all_of(operation->getOperands(), isRanked) &&
741  llvm::all_of(operation->getResults(), isRanked);
742 }
743 
744 // Compute the runtime dimension size for dimension 'dim' of the output by
745 // inspecting input 'operands', all of which are expected to have the same rank.
746 // This function returns a pair {targetSize, masterOperand}.
747 //
748 // The runtime size of the output dimension is returned either as a statically
749 // computed attribute or as a runtime SSA value.
750 //
751 // If the target size was inferred directly from one dominating operand, that
752 // operand is returned in 'masterOperand'. If the target size is inferred from
753 // multiple operands, 'masterOperand' is set to nullptr.
754 static std::pair<OpFoldResult, Value>
756  ValueRange operands, int64_t dim) {
757  // If any input operand contains a static size greater than 1 for this
758  // dimension, that is the target size. An occurrence of an additional static
759  // dimension greater than 1 with a different value is undefined behavior.
760  for (auto operand : operands) {
761  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
762  if (!ShapedType::isDynamic(size) && size > 1)
763  return {rewriter.getIndexAttr(size), operand};
764  }
765 
766  // Filter operands with dynamic dimension
767  auto operandsWithDynamicDim =
768  llvm::filter_to_vector(operands, [&](Value operand) {
769  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
770  });
771 
772  // If no operand has a dynamic dimension, it means all sizes were 1
773  if (operandsWithDynamicDim.empty())
774  return {rewriter.getIndexAttr(1), operands.front()};
775 
776  // Emit code that computes the runtime size for this dimension. If there is
777  // only one operand with a dynamic dimension, it is considered the master
778  // operand that determines the runtime size of the output dimension.
779  auto targetSize =
780  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
781  if (operandsWithDynamicDim.size() == 1)
782  return {targetSize, operandsWithDynamicDim[0]};
783 
784  // Calculate maximum size among all dynamic dimensions
785  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
786  auto nextSize =
787  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
788  targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
789  }
790  return {targetSize, nullptr};
791 }
792 
793 // Compute the runtime output size for all dimensions. This function returns
794 // a pair {targetShape, masterOperands}.
795 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
797  IndexPool &indexPool, ValueRange operands) {
798  assert(!operands.empty());
799  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
800  SmallVector<OpFoldResult> targetShape;
801  SmallVector<Value> masterOperands;
802  for (auto dim : llvm::seq<int64_t>(0, rank)) {
803  auto [targetSize, masterOperand] =
804  computeTargetSize(rewriter, loc, indexPool, operands, dim);
805  targetShape.push_back(targetSize);
806  masterOperands.push_back(masterOperand);
807  }
808  return {targetShape, masterOperands};
809 }
810 
812  IndexPool &indexPool, Value operand,
813  int64_t dim, OpFoldResult targetSize,
814  Value masterOperand) {
815  // Nothing to do if this is a static dimension
816  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
817  if (!rankedTensorType.isDynamicDim(dim))
818  return operand;
819 
820  // If the target size for this dimension was directly inferred by only taking
821  // this operand into account, there is no need to broadcast. This is an
822  // optimization that will prevent redundant control flow, and constitutes the
823  // main motivation for tracking "master operands".
824  if (operand == masterOperand)
825  return operand;
826 
827  // Affine maps for 'linalg.generic' op
828  auto rank = rankedTensorType.getRank();
829  SmallVector<AffineExpr> affineExprs;
830  for (auto index : llvm::seq<int64_t>(0, rank)) {
831  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
832  : rewriter.getAffineDimExpr(index);
833  affineExprs.push_back(affineExpr);
834  }
835  auto broadcastAffineMap =
836  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
837  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
838  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
839 
840  // Check if broadcast is necessary
841  auto one = createIndex(rewriter, loc, indexPool, 1);
842  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
843  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
844  loc, arith::CmpIPredicate::eq, runtimeSize, one);
845 
846  // Emit 'then' region of 'scf.if'
847  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
848  // It is not safe to cache constants across regions.
849  // New constants could potentially violate dominance requirements.
850  IndexPool localPool;
851 
852  // Emit 'tensor.empty' op
853  SmallVector<OpFoldResult> outputTensorShape;
854  for (auto index : llvm::seq<int64_t>(0, rank)) {
855  auto size = index == dim ? targetSize
856  : getOrFoldTensorDim(rewriter, loc, localPool,
857  operand, index);
858  outputTensorShape.push_back(size);
859  }
860  Value outputTensor = opBuilder.create<tensor::EmptyOp>(
861  loc, outputTensorShape, rankedTensorType.getElementType());
862 
863  // Emit 'linalg.generic' op
864  auto resultTensor =
865  opBuilder
866  .create<linalg::GenericOp>(
867  loc, outputTensor.getType(), operand, outputTensor, affineMaps,
869  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
870  // Emit 'linalg.yield' op
871  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
872  })
873  .getResult(0);
874 
875  // Cast to original operand type if necessary
876  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
877  loc, operand.getType(), resultTensor);
878 
879  // Emit 'scf.yield' op
880  opBuilder.create<scf::YieldOp>(loc, castResultTensor);
881  };
882 
883  // Emit 'else' region of 'scf.if'
884  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
885  opBuilder.create<scf::YieldOp>(loc, operand);
886  };
887 
888  // Emit 'scf.if' op
889  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
890  emitThenRegion, emitElseRegion);
891  return ifOp.getResult(0);
892 }
893 
895  IndexPool &indexPool, Value operand,
896  ArrayRef<OpFoldResult> targetShape,
897  ArrayRef<Value> masterOperands) {
898  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
899  assert((int64_t)targetShape.size() == rank);
900  assert((int64_t)masterOperands.size() == rank);
901  for (auto index : llvm::seq<int64_t>(0, rank))
902  operand =
903  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
904  targetShape[index], masterOperands[index]);
905  return operand;
906 }
907 
908 static SmallVector<Value>
910  IndexPool &indexPool, ValueRange operands,
911  ArrayRef<OpFoldResult> targetShape,
912  ArrayRef<Value> masterOperands) {
913  // No need to broadcast for unary operations
914  if (operands.size() == 1)
915  return operands;
916 
917  // Broadcast dynamic dimensions operand by operand
918  return llvm::map_to_vector(operands, [&](Value operand) {
919  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
920  targetShape, masterOperands);
921  });
922 }
923 
924 static LogicalResult
926  Operation *operation, ValueRange operands,
927  ArrayRef<OpFoldResult> targetShape,
928  const TypeConverter &converter) {
929  // Generate output tensor
930  auto resultType = cast_or_null<RankedTensorType>(
931  converter.convertType(operation->getResultTypes().front()));
932  if (!resultType) {
933  return rewriter.notifyMatchFailure(operation, "failed to convert type");
934  }
935  Value outputTensor = rewriter.create<tensor::EmptyOp>(
936  loc, targetShape, resultType.getElementType());
937 
938  // Create affine maps. Input affine maps broadcast static dimensions of size
939  // 1. The output affine map is an identity map.
940  //
941  auto rank = resultType.getRank();
942  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
943  auto shape = cast<ShapedType>(operand.getType()).getShape();
944  SmallVector<AffineExpr> affineExprs;
945  for (auto it : llvm::enumerate(shape)) {
946  // Prefer producting identity maps whenever possible (i.e. no broadcasting
947  // needed) because some transforms (like reshape folding)
948  // do not support affine constant exprs.
949  bool requiresBroadcast =
950  (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
951  auto affineExpr = requiresBroadcast
952  ? rewriter.getAffineConstantExpr(0)
953  : rewriter.getAffineDimExpr(it.index());
954  affineExprs.push_back(affineExpr);
955  }
956  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
957  });
958  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
959 
960  // Emit 'linalg.generic' op
961  bool encounteredError = false;
962  auto linalgOp = rewriter.create<linalg::GenericOp>(
963  loc, outputTensor.getType(), operands, outputTensor, affineMaps,
965  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
967  operation, blockArgs.take_front(operation->getNumOperands()),
968  {resultType.getElementType()}, rewriter);
969  if (!opResult) {
970  encounteredError = true;
971  return;
972  }
973  opBuilder.create<linalg::YieldOp>(loc, opResult);
974  });
975  if (encounteredError)
976  return rewriter.notifyMatchFailure(
977  operation, "unable to create linalg.generic body for elementwise op");
978 
979  // Cast 'linalg.generic' result into original result type if needed
980  auto castResult = rewriter.createOrFold<tensor::CastOp>(
981  loc, resultType, linalgOp->getResult(0));
982  rewriter.replaceOp(operation, castResult);
983  return success();
984 }
985 
987  ValueRange operands) {
988  // Shift cannot broadcast
989  if (isa<tosa::MulOp>(operation))
990  return operands.take_front(2);
991  // Input1_zp and output_zp cannot broadcast
992  if (isa<tosa::NegateOp>(operation))
993  return operands.take_front(1);
994  return operands;
995 }
996 
997 static LogicalResult
999  ConversionPatternRewriter &rewriter,
1000  const TypeConverter &converter) {
1001 
1002  // Collect op properties
1003  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1004  assert(operation->getNumOperands() >= 1 &&
1005  "elementwise op expects at least 1 operand");
1006  if (!operandsAndResultsRanked(operation))
1007  return rewriter.notifyMatchFailure(operation,
1008  "Unranked tensors not supported");
1009 
1010  // Lower operation
1011  IndexPool indexPool;
1012  auto loc = operation->getLoc();
1013  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1014  auto [targetShape, masterOperands] =
1015  computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
1016  auto broadcastOperands =
1017  broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
1018  targetShape, masterOperands);
1019  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
1020  targetShape, converter);
1021 }
1022 
1023 // Returns the constant initial value for a given reduction operation. The
1024 // attribute type varies depending on the element type required.
1025 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1026  PatternRewriter &rewriter) {
1027  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1028  return rewriter.getFloatAttr(elementTy, 0.0);
1029 
1030  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1031  return rewriter.getIntegerAttr(elementTy, 0);
1032 
1033  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1034  return rewriter.getFloatAttr(elementTy, 1.0);
1035 
1036  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1037  return rewriter.getIntegerAttr(elementTy, 1);
1038 
1039  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1040  return rewriter.getFloatAttr(
1041  elementTy, APFloat::getLargest(
1042  cast<FloatType>(elementTy).getFloatSemantics(), false));
1043 
1044  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1045  return rewriter.getIntegerAttr(
1046  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1047 
1048  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1049  return rewriter.getFloatAttr(
1050  elementTy, APFloat::getLargest(
1051  cast<FloatType>(elementTy).getFloatSemantics(), true));
1052 
1053  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1054  return rewriter.getIntegerAttr(
1055  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1056 
1057  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1058  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1059 
1060  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1061  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1062 
1063  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1064  return rewriter.getFloatAttr(
1065  elementTy, APFloat::getLargest(
1066  cast<FloatType>(elementTy).getFloatSemantics(), true));
1067 
1068  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1069  return rewriter.getIntegerAttr(
1070  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1071 
1072  return {};
1073 }
1074 
1075 // Creates the body calculation for a reduction. The operations vary depending
1076 // on the input type.
1078  ValueRange args,
1079  Type elementTy,
1080  PatternRewriter &rewriter) {
1081  Location loc = op->getLoc();
1082  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1083  return rewriter.create<arith::AddFOp>(loc, args);
1084  }
1085 
1086  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1087  return rewriter.create<arith::AddIOp>(loc, args);
1088  }
1089 
1090  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1091  return rewriter.create<arith::MulFOp>(loc, args);
1092  }
1093 
1094  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1095  return rewriter.create<arith::MulIOp>(loc, args);
1096  }
1097 
1098  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1099  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1100  }
1101 
1102  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1103  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1104  }
1105 
1106  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1107  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1108  }
1109 
1110  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1111  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1112  }
1113 
1114  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1115  return rewriter.create<arith::AndIOp>(loc, args);
1116 
1117  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1118  return rewriter.create<arith::OrIOp>(loc, args);
1119 
1120  return {};
1121 }
1122 
1123 // Performs the match and rewrite for reduction operations. This includes
1124 // declaring a correctly sized initial value, and the linalg.generic operation
1125 // that reduces across the specified axis.
1126 template <typename OpTy>
1127 static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1128  PatternRewriter &rewriter) {
1129  auto loc = op->getLoc();
1130  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1131  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1132  if (!inputTy || !resultTy)
1133  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1134 
1135  auto elementTy = resultTy.getElementType();
1136  Value input = op->getOperand(0);
1137 
1138  SmallVector<int64_t> reduceShape;
1139  SmallVector<Value> dynDims;
1140  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1141  if (axis != i) {
1142  reduceShape.push_back(inputTy.getDimSize(i));
1143  if (inputTy.isDynamicDim(i))
1144  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1145  }
1146  }
1147 
1148  SmallVector<Value> inputs, outputs;
1149  inputs.push_back(input);
1150 
1151  // First fill the output buffer with the init value.
1152  auto emptyTensor =
1153  rewriter
1154  .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1155  dynDims)
1156  .getResult();
1157 
1158  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1159  if (!fillValueAttr)
1160  return rewriter.notifyMatchFailure(
1161  op, "No initial value found for reduction operation");
1162 
1163  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1164  auto filledTensor = rewriter
1165  .create<linalg::FillOp>(loc, ValueRange{fillValue},
1166  ValueRange{emptyTensor})
1167  .result();
1168  outputs.push_back(filledTensor);
1169 
1170  bool isNanIgnoreMode = false;
1171  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1172  std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1173  // NaN propagation has no meaning for non floating point types.
1174  if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
1175  isNanIgnoreMode = true;
1176  // Because the TOSA spec requires the result be NaN iff all elements in
1177  // the reduction are NaN we can't simply perform a compare and select.
1178  // Additionally we have to keep track of whether we've seen any non-NaN
1179  // values and then do a final select based on this predicate.
1180  auto trueAttr = rewriter.getBoolAttr(true);
1181  auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
1182  auto emptyBoolTensor =
1183  rewriter
1184  .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
1185  dynDims)
1186  .getResult();
1187  auto allResultsNaNTensor =
1188  rewriter
1189  .create<linalg::FillOp>(loc, ValueRange{trueValue},
1190  ValueRange{emptyBoolTensor})
1191  .result();
1192  // Note that because the linalg::ReduceOp has two variadic arguments
1193  // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1194  // need to have the same number of inputs and outputs.
1195  //
1196  // The second input isn't actually used anywhere since the value used to
1197  // update the NaN flag is calculated inside the body of the reduction and
1198  // then used to update an out value.
1199  // In order to satisfy type constraints we just pass another copy of the
1200  // input here.
1201  inputs.push_back(input);
1202  outputs.push_back(allResultsNaNTensor);
1203  }
1204  }
1205 
1206  bool didEncounterError = false;
1207  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
1208  loc, inputs, outputs, axis,
1209  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1210  std::array<Value, 2> binaryArgs{
1211  blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1213  op, binaryArgs, elementTy, rewriter);
1214  if (result)
1215  didEncounterError = true;
1216 
1217  SmallVector<Value> resultsToYield;
1218  if (isNanIgnoreMode) {
1219  auto inputValue = blockArgs[0];
1220  auto initialValue = blockArgs[2];
1221  auto oldAllResultsNanFlagValue = blockArgs[3];
1222 
1223  // Unordered comparison of NaN against itself will always return true.
1224  Value isNaN = nestedBuilder.create<arith::CmpFOp>(
1225  op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
1226  // If we've encountered a NaN, take the non-NaN value.
1227  auto selectOp = nestedBuilder.create<arith::SelectOp>(
1228  op->getLoc(), isNaN, initialValue, result);
1229  // Update the flag which keeps track of whether we have seen a non-NaN
1230  // value.
1231  auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
1232  op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1233  resultsToYield.push_back(selectOp);
1234  resultsToYield.push_back(newAllResultsNanFlagValue);
1235  } else {
1236  resultsToYield.push_back(result);
1237  }
1238  nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
1239  });
1240 
1241  if (!didEncounterError)
1242  return rewriter.notifyMatchFailure(
1243  op, "unable to create linalg.generic body for reduce op");
1244 
1245  if (isNanIgnoreMode) {
1246  // Materialize a check to see whether we encountered any non-NaN values, if
1247  // we didn't we need to select a tensor of NaNs since the result will just
1248  // be the initial identity value propagated through all the compares and
1249  // selects inside the reduction.
1250 
1251  // Create a tensor full of NaNs.
1252  auto nanValueAttr = rewriter.getFloatAttr(
1253  elementTy,
1254  APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1255  auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
1256  auto emptyNanTensor =
1257  rewriter
1258  .create<tensor::EmptyOp>(loc, reduceShape,
1259  resultTy.getElementType(), dynDims)
1260  .getResult();
1261  auto nanFilledTensor =
1262  rewriter
1263  .create<linalg::FillOp>(loc, ValueRange{nanValue},
1264  ValueRange{emptyNanTensor})
1265  .result();
1266 
1267  // Create an empty tensor, non need to fill this since it will be
1268  // overwritten by the select.
1269  auto finalEmptyTensor =
1270  rewriter
1271  .create<tensor::EmptyOp>(loc, reduceShape,
1272  resultTy.getElementType(), dynDims)
1273  .getResult();
1274 
1275  // Do a selection between the tensors akin to:
1276  // result = NaN if "all results NaN" else result.
1277  SmallVector<Value> ins, outs;
1278  ins.push_back(linalgOp->getOpResult(1));
1279  ins.push_back(nanFilledTensor);
1280  ins.push_back(linalgOp->getResult(0));
1281  outs.push_back(finalEmptyTensor);
1282  auto linalgSelect =
1283  rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
1284  linalgOp = linalgSelect;
1285  }
1286 
1287  SmallVector<ReassociationExprs, 4> reassociationMap;
1288  uint64_t expandInputRank =
1289  cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
1290  reassociationMap.resize(expandInputRank);
1291 
1292  for (uint64_t i = 0; i < expandInputRank; i++) {
1293  int32_t dimToPush = i > axis ? i + 1 : i;
1294  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1295  }
1296 
1297  if (expandInputRank != 0) {
1298  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1299  reassociationMap[expandedDim].push_back(
1300  rewriter.getAffineDimExpr(expandedDim + 1));
1301  }
1302 
1303  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1304  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1305  // not have access to such information. This matters when handling dynamically
1306  // sized tensors.
1307  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1308  op, resultTy, linalgOp->getResults()[0], reassociationMap);
1309  return success();
1310 }
1311 
1312 namespace {
1313 
1314 template <typename SrcOp>
1315 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1316 public:
1319 
1320  LogicalResult
1321  matchAndRewrite(SrcOp op, OpAdaptor operands,
1322  ConversionPatternRewriter &rewriter) const final {
1324  op, operands.getOperands(), rewriter, *this->getTypeConverter());
1325  }
1326 };
1327 
1328 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1329 public:
1331 
1332  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1333  PatternRewriter &rewriter) const final {
1334  auto loc = op.getLoc();
1335  auto input = op.getInput();
1336  auto inputTy = cast<ShapedType>(op.getInput().getType());
1337  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1338  unsigned rank = inputTy.getRank();
1339 
1340  // This is an illegal configuration. terminate and log an error
1341  if (op.getRoundingMode() == "INEXACT_ROUND")
1342  return rewriter.notifyMatchFailure(
1343  op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1344  "currently supported");
1345  if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
1346  return rewriter.notifyMatchFailure(
1347  op, "tosa.rescale requires scale32 for double_round to be true");
1348 
1349  if (!isa<IntegerType>(inputTy.getElementType()))
1350  return rewriter.notifyMatchFailure(op, "only support integer type");
1351 
1352  SmallVector<Value> dynDims;
1353  for (int i = 0; i < outputTy.getRank(); i++) {
1354  if (outputTy.isDynamicDim(i)) {
1355  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1356  }
1357  }
1358 
1359  // The shift and multiplier values.
1360  DenseElementsAttr shiftElems;
1361  if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1362  return rewriter.notifyMatchFailure(
1363  op, "tosa.rescale requires constant shift input values");
1364 
1365  DenseElementsAttr multiplierElems;
1366  if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1367  return rewriter.notifyMatchFailure(
1368  op, "tosa.rescale requires constant multiplier input values");
1369 
1370  llvm::SmallVector<int8_t> shiftValues =
1371  llvm::to_vector(shiftElems.getValues<int8_t>());
1372  // explicit cast is required here
1373  llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1374  llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1375  [](IntegerAttr attr) -> int32_t {
1376  return static_cast<int32_t>(attr.getInt());
1377  }));
1378 
1379  // If we shift by more than the bitwidth, this just sets to 0.
1380  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1381  if (shiftValues[i] > 63) {
1382  shiftValues[i] = 0;
1383  multiplierValues[i] = 0;
1384  }
1385  }
1386 
1387  // Double round only occurs if shift is greater than 31, check that this
1388  // is ever true.
1389 
1390  bool doubleRound =
1391  op.getRoundingMode() == "DOUBLE_ROUND" &&
1392  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1393  StringAttr roundingMode = doubleRound
1394  ? rewriter.getStringAttr("DOUBLE_ROUND")
1395  : rewriter.getStringAttr("SINGLE_ROUND");
1396 
1397  SmallVector<AffineMap> indexingMaps = {
1398  rewriter.getMultiDimIdentityMap(rank)};
1399  SmallVector<Value, 4> genericInputs = {input};
1400 
1401  // If we are rescaling per-channel then we need to store the multiplier
1402  // values in a buffer.
1403  Value multiplierConstant;
1404  int64_t multiplierArg = 0;
1405  if (multiplierValues.size() == 1) {
1406  multiplierConstant = rewriter.create<arith::ConstantOp>(
1407  loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1408  } else {
1409  SmallVector<AffineExpr, 2> multiplierExprs{
1410  rewriter.getAffineDimExpr(rank - 1)};
1411  auto multiplierType =
1412  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1413  rewriter.getI32Type());
1414  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1415  loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1416 
1417  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1418  /*symbolCount=*/0, multiplierExprs,
1419  rewriter.getContext()));
1420 
1421  multiplierArg = indexingMaps.size() - 1;
1422  }
1423 
1424  // If we are rescaling per-channel then we need to store the shift
1425  // values in a buffer.
1426  Value shiftConstant;
1427  int64_t shiftArg = 0;
1428  if (shiftValues.size() == 1) {
1429  shiftConstant = rewriter.create<arith::ConstantOp>(
1430  loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1431  } else {
1432  SmallVector<AffineExpr, 2> shiftExprs = {
1433  rewriter.getAffineDimExpr(rank - 1)};
1434  auto shiftType =
1435  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1436  rewriter.getIntegerType(8));
1437  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1438  loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1439  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1440  /*symbolCount=*/0, shiftExprs,
1441  rewriter.getContext()));
1442  shiftArg = indexingMaps.size() - 1;
1443  }
1444 
1445  // Indexing maps for output values.
1446  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1447 
1448  // Construct the indexing maps needed for linalg.generic ops.
1449  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1450  loc, outputTy.getShape(), outputTy.getElementType(),
1451  ArrayRef<Value>({dynDims}));
1452 
1453  auto linalgOp = rewriter.create<linalg::GenericOp>(
1454  loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1455  getNParallelLoopsAttrs(rank),
1456  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1457  ValueRange blockArgs) {
1458  Value value = blockArgs[0];
1459  Type valueTy = value.getType();
1460 
1461  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1462  if (failed(maybeIZp)) {
1463  (void)rewriter.notifyMatchFailure(
1464  op, "input zero point cannot be statically determined");
1465  return;
1466  }
1467 
1468  const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1469  // Extend zeropoint for sub-32bits widths.
1470  const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1471  auto inputZp = nestedBuilder.create<arith::ConstantOp>(
1472  loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
1473  *maybeIZp));
1474 
1475  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1476  if (failed(maybeOZp)) {
1477  (void)rewriter.notifyMatchFailure(
1478  op, "output zero point cannot be statically determined");
1479  return;
1480  };
1481 
1482  IntegerType outIntType =
1483  cast<IntegerType>(blockArgs.back().getType());
1484  unsigned outBitWidth = outIntType.getWidth();
1485  const int32_t outAttrBitwidth = 32;
1486  assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1487  auto outputZp = nestedBuilder.create<arith::ConstantOp>(
1488  loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
1489  *maybeOZp));
1490 
1491  Value multiplier = multiplierConstant ? multiplierConstant
1492  : blockArgs[multiplierArg];
1493  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1494 
1495  if (valueTy.getIntOrFloatBitWidth() < 32) {
1496  if (op.getInputUnsigned()) {
1497  value = nestedBuilder.create<arith::ExtUIOp>(
1498  nestedLoc, nestedBuilder.getI32Type(), value);
1499  } else {
1500  value = nestedBuilder.create<arith::ExtSIOp>(
1501  nestedLoc, nestedBuilder.getI32Type(), value);
1502  }
1503  }
1504 
1505  value =
1506  nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1507 
1508  value = nestedBuilder.create<tosa::ApplyScaleOp>(
1509  loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1510  roundingMode);
1511 
1512  // Move to the new zero-point.
1513  value =
1514  nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1515 
1516  // Saturate to the output size.
1517  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1518  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1519 
1520  // Unsigned integers have a difference output value.
1521  if (op.getOutputUnsigned()) {
1522  intMin = 0;
1523  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1524  }
1525 
1526  auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1527  loc, nestedBuilder.getI32IntegerAttr(intMin));
1528  auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1529  loc, nestedBuilder.getI32IntegerAttr(intMax));
1530 
1531  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1532  nestedBuilder, /*isUnsigned=*/false);
1533 
1534  if (outIntType.getWidth() < 32) {
1535  value = nestedBuilder.create<arith::TruncIOp>(
1536  nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1537  value);
1538  }
1539 
1540  nestedBuilder.create<linalg::YieldOp>(loc, value);
1541  });
1542 
1543  rewriter.replaceOp(op, linalgOp->getResults());
1544  return success();
1545  }
1546 };
1547 
1548 // Handle the resize case where the input is a 1x1 image. This case
1549 // can entirely avoiding having extract operations which target much
1550 // more difficult to optimize away.
1551 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1552 public:
1554 
1555  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1556  PatternRewriter &rewriter) const final {
1557  Location loc = op.getLoc();
1558  ImplicitLocOpBuilder builder(loc, rewriter);
1559  auto input = op.getInput();
1560  auto inputTy = cast<RankedTensorType>(input.getType());
1561  auto resultTy = cast<RankedTensorType>(op.getType());
1562  const bool isBilinear = op.getMode() == "BILINEAR";
1563 
1564  auto inputH = inputTy.getDimSize(1);
1565  auto inputW = inputTy.getDimSize(2);
1566  auto outputH = resultTy.getDimSize(1);
1567  auto outputW = resultTy.getDimSize(2);
1568 
1569  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1570  return rewriter.notifyMatchFailure(
1571  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1572 
1573  // TODO(suderman): These string values should be declared the TOSA dialect.
1574  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1575  return rewriter.notifyMatchFailure(
1576  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1577 
1578  if (inputTy == resultTy) {
1579  rewriter.replaceOp(op, input);
1580  return success();
1581  }
1582 
1583  SmallVector<int64_t> scale;
1584  if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
1585  return failure();
1586  }
1587 
1588  // Collapse the unit width and height away.
1589  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1590  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1591  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1592  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1593  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1594 
1595  auto collapseTy =
1596  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1597  inputTy.getElementType());
1598  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1599  reassociationMap);
1600 
1601  // Get any dynamic shapes that appear in the input format.
1602  llvm::SmallVector<Value> outputDynSize;
1603  if (inputTy.isDynamicDim(0))
1604  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1605  if (inputTy.isDynamicDim(3))
1606  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1607 
1608  // Generate the elementwise operation for casting scaling the input value.
1609  auto genericTy = collapseTy.clone(resultTy.getElementType());
1610  Value empty = builder.create<tensor::EmptyOp>(
1611  genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1612  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1613  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1614  utils::IteratorType::parallel);
1615 
1616  auto generic = builder.create<linalg::GenericOp>(
1617  genericTy, ValueRange{collapse}, ValueRange{empty},
1618  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1619  [=](OpBuilder &b, Location loc, ValueRange args) {
1620  Value value = args[0];
1621  // This is the quantized case.
1622  if (inputTy.getElementType() != resultTy.getElementType()) {
1623  value =
1624  b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1625 
1626  if (isBilinear && scale[0] != 0) {
1627  Value scaleY = b.create<arith::ConstantOp>(
1628  loc, b.getI32IntegerAttr(scale[0]));
1629  value = b.create<arith::MulIOp>(loc, value, scaleY);
1630  }
1631 
1632  if (isBilinear && scale[2] != 0) {
1633  Value scaleX = b.create<arith::ConstantOp>(
1634  loc, b.getI32IntegerAttr(scale[2]));
1635  value = b.create<arith::MulIOp>(loc, value, scaleX);
1636  }
1637  }
1638 
1639  b.create<linalg::YieldOp>(loc, value);
1640  });
1641 
1642  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1643  op, resultTy, generic.getResults()[0], reassociationMap);
1644  return success();
1645  }
1646 };
1647 
// TOSA resize with width or height of 1 may be broadcasted to a wider
// dimension. This is done by materializing a new tosa.resize without
// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped from this extract; confirm against the upstream file.

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    // NHWC layout: dim 0 = batch, 1 = height, 2 = width, 3 = channels.
    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Only match when at least one spatial dim is a 1 -> N broadcast
    // (input extent 1, output extent > 1).
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    // Re-emit the resize with the broadcast dims forced to 1 so it no
    // longer relies on broadcasting behavior.
    auto resizeTy = resultTy.clone(resizeShape);
    auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
                                                 op.getOffset(), op.getBorder(),
                                                 op.getMode());

    // Collapse any unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    // Shape of the resize result with the broadcast (unit) dims removed.
    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // The input map omits the broadcast dims so every output element along
    // those dims reads the single source element.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    // The broadcast itself is a parallel copy generic.
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
1747 
// Fully general lowering of tosa.resize. Emits a linalg.generic whose body
// computes, per output element, the source coordinate as an integer index
// plus a fractional delta, then samples the input via tensor.extract using
// either nearest-neighbor selection or bilinear interpolation. Supports both
// a floating-point path (f16/f32) and an integer (quantized) path.
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped from this extract; confirm against the upstream file.

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    // f16/f32 results interpolate in floating point; any other element type
    // takes the integer arithmetic path below.
    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    // Input spatial extents (NHWC layout: dims 1 and 2).
    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // Output-only generic: the input is read with tensor.extract inside the
    // body rather than being passed as a linalg operand.
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      // Current output coordinates (NHWC).
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      // Clamp bounds for the sample indices.
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      // scale/offset/border must be compile-time constants.
      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      // Per the TOSA resize definition, scale = [yN, yD, xN, xD]
      // (numerator/denominator per axis).
      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the ix and dx values for both the X and Y dimensions.
      // (index = integer sample position, delta = fractional remainder.)
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        // A unit-extent axis always samples element 0 with no fraction.
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Compute the ix and dx values for the X and Y dimensions - int case.
      // Here delta stays an integer numerator over scale_n.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n;
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        // Round to the nearest sample: bump the index when the fractional
        // part is >= 0.5 (fp) or 2*delta >= scale_n (int), then clamp.
        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        // Compute the two sample positions (idx and idx+1), clamped to the
        // image bounds and cast to index type.
        // NOTE(review): a parameter/header line of this lambda (likely
        // `Value max, ImplicitLocOpBuilder &b) {`) appears to have been
        // dropped from this extract; confirm against the upstream file.
        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        // Linalg equivalent to the section below:
        //   int16_t iy0 = apply_max(iy, 0);
        //   int16_t iy1 = apply_min(iy + 1, IH - 1);
        //   int16_t ix0 = apply_max(ix, 0);
        //   int16_t ix1 = apply_min(ix + 1, IW - 1);
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        // Load the four neighboring samples.
        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          // Linear blend: val0 * (1 - delta) + val1 * delta. A unit-extent
          // axis has only one sample, so return it unblended.
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // Linalg equivalent to the section below:
          //   topAcc = v00 * (unit_x - dx);
          //   topAcc += v01 * dx;
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // Linalg equivalent to the section below:
          //   bottomAcc = v10 * (unit_x - dx);
          //   bottomAcc += v11 * dx;
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // Linalg equivalent to the section below:
          //   result = topAcc * (unit_y - dy) + bottomAcc * dy
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform in quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          // Widen deltas to the result bitwidth before multiplying.
          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          // Likewise widen the scale numerators if narrower than the result.
          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          // Integer blend: val0 * (scale - w1) + val1 * w1, keeping the
          // implicit /scale factor in the accumulated result.
          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
2020 
// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross device computation) should be
// handled before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped from this extract; confirm against the upstream file.

  // Erases the identity op, forwarding its operands to all users.
  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};
2035 
// Generic lowering for TOSA reduction ops (sum/prod/min/max/any/all);
// delegates all work to the shared reduceMatchAndRewriteHelper.
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped from this extract; confirm against the upstream file.

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
2046 
// Lowers tosa.reverse to a linalg.generic that gathers each output element
// from the mirrored index along the reversed axis.
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped from this extract; confirm against the upstream file.

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    // Gather tensor.dim values for dynamic dims so the destination tensor
    // can be created with matching runtime sizes.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // Runtime extent of the reversed axis, used to mirror indices below.
    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              // Mirror the index on the reversed axis:
              // out[i] = in[size - 1 - i].
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
2104 
// This converter translates a tile operation to a reshape, broadcast,
// reshape. The first reshape minimally expands each tiled dimension to
// include a preceding size-1 dim. This dim is then broadcasted to the
// appropriate multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped from this extract; confirm against the upstream file.

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Multiples must be compile-time constants (-1 marks a dynamic one).
    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    // genericShape interleaves [multiple_i, inputDim_i] pairs per axis.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions,
    // i.e. every odd position of the interleaved shape.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    // Parallel copy generic performing the broadcast over the multiple dims.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    // Final reshape collapses the interleaved shape to the tiled result.
    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
2173 
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped from this extract; confirm against the upstream file.

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Type of the auxiliary running-max output (same shape, input elt type).
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Dynamic dims of the result exclude the reduced axis.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    // NOTE(review): the declaration of `iteratorTypes` (likely
    // `SmallVector<utils::IteratorType, 4> iteratorTypes;`) appears to have
    // been dropped from this extract; confirm against the upstream file.
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    // Source map is identity; both outputs drop the reduced axis.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Candidate index is the current position along the reduced axis.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == "IGNORE") {
              // Only update index & max value for non NaN values. If all
              // values are NaNs, the initial index will be return which is 0.
              predicate = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
            } else {
              // Update max value if either of the following is true:
              // - new value is bigger
              // - cur max is not NaN and new value is NaN
              Value gt = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
              predicate = rewriter.create<arith::AndIOp>(
                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            // Unsupported element type; flag and bail out after the build.
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index result is used; the running-max buffer is dropped.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
2319 
// Lowers tosa.gather to a linalg.generic over the indices tensor whose body
// extracts values[batch, indices[batch, n], channel] from the values tensor.
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped from this extract; confirm against the upstream file.
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    // Input (indices) map reads dims (0, 1); output map is identity.
    // NOTE(review): the first line of this initializer (likely
    // `AffineMap::get(`) appears to have been dropped from this extract;
    // confirm against the upstream file.
    SmallVector<AffineMap, 2> affineMaps = {
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          // Gather: result[b, n, c] = values[b, indices[b, n], c].
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  // Collects the runtime sizes of the result's dynamic dims: batch and
  // channel come from `values` (dims 0, 2), N comes from `indices` (dim 1).
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    // getMixedSize yields a Value only for dynamic dims; static dims are
    // attributes and are skipped.
    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
2389 
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant,
// this simplifies to a single gather operation.
2393 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2394 public:
2396 
2397  LogicalResult matchAndRewrite(tosa::TableOp op,
2398  PatternRewriter &rewriter) const final {
2399  auto loc = op.getLoc();
2400  Value input = op.getInput1();
2401  Value table = op.getTable();
2402  auto inputTy = cast<ShapedType>(input.getType());
2403  auto tableTy = cast<ShapedType>(table.getType());
2404  auto resultTy = cast<ShapedType>(op.getType());
2405 
2406  auto inputElementTy = inputTy.getElementType();
2407  auto tableElementTy = tableTy.getElementType();
2408  auto resultElementTy = resultTy.getElementType();
2409 
2410  SmallVector<Value> dynDims;
2411  for (int i = 0; i < resultTy.getRank(); ++i) {
2412  if (inputTy.isDynamicDim(i)) {
2413  dynDims.push_back(
2414  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2415  }
2416  }
2417 
2418  auto emptyTensor = rewriter
2419  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2420  resultElementTy, dynDims)
2421  .getResult();
2422 
2423  SmallVector<AffineMap, 2> affineMaps = {
2424  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2425  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2426 
2427  auto genericOp = rewriter.create<linalg::GenericOp>(
2428  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2429  getNParallelLoopsAttrs(resultTy.getRank()));
2430  rewriter.replaceOp(op, genericOp.getResult(0));
2431 
2432  {
2433  OpBuilder::InsertionGuard regionGuard(rewriter);
2434  Block *block = rewriter.createBlock(
2435  &genericOp.getRegion(), genericOp.getRegion().end(),
2436  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2437 
2438  auto inputValue = block->getArgument(0);
2439  rewriter.setInsertionPointToStart(block);
2440  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2441  resultElementTy.isInteger(8)) {
2442  Value index = rewriter.create<arith::IndexCastOp>(
2443  loc, rewriter.getIndexType(), inputValue);
2444  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2445  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2446  index, offset);
2447  Value extract =
2448  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2449  rewriter.create<linalg::YieldOp>(loc, extract);
2450  return success();
2451  }
2452 
2453  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2454  resultElementTy.isInteger(32)) {
2455  Value extend = rewriter.create<arith::ExtSIOp>(
2456  loc, rewriter.getI32Type(), inputValue);
2457 
2458  auto offset = rewriter.create<arith::ConstantOp>(
2459  loc, rewriter.getI32IntegerAttr(32768));
2460  auto seven = rewriter.create<arith::ConstantOp>(
2461  loc, rewriter.getI32IntegerAttr(7));
2462  auto one = rewriter.create<arith::ConstantOp>(
2463  loc, rewriter.getI32IntegerAttr(1));
2464  auto b1111111 = rewriter.create<arith::ConstantOp>(
2465  loc, rewriter.getI32IntegerAttr(127));
2466 
2467  // Compute the index and fractional part from the input value:
2468  // value = value + 32768
2469  // index = value >> 7;
2470  // fraction = 0x01111111 & value
2471  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2472  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2473  Value fraction =
2474  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2475 
2476  // Extract the base and next values from the table.
2477  // base = (int32_t) table[index];
2478  // next = (int32_t) table[index + 1];
2479  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2480 
2481  index = rewriter.create<arith::IndexCastOp>(
2482  loc, rewriter.getIndexType(), index);
2483  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2484  loc, rewriter.getIndexType(), indexPlusOne);
2485 
2486  Value base =
2487  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2488  Value next = rewriter.create<tensor::ExtractOp>(
2489  loc, table, ValueRange{indexPlusOne});
2490 
2491  base =
2492  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2493  next =
2494  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2495 
2496  // Use the fractional part to interpolate between the input values:
2497  // result = (base << 7) + (next - base) * fraction
2498  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2499  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2500  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2501  Value result =
2502  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2503 
2504  rewriter.create<linalg::YieldOp>(loc, result);
2505 
2506  return success();
2507  }
2508  }
2509 
2510  return rewriter.notifyMatchFailure(
2511  op, "unable to create body for tosa.table op");
2512  }
2513 };
2514 
// Lowers tosa.rfft2d to a linalg.generic that accumulates the DFT sums
// directly: two parallel output loops over (N, oy, ox) and two reduction
// loops over the input (iy, ix).
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  // Returns (ofr / 2) + 1, folding to a constant when possible.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  // Computes the [N, H, W/2 + 1] output type from the input tensor,
  // appending a size Value to `dynamicSizes` for every dynamic dimension.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  // Creates a tensor of `type` filled with the element type's zero value,
  // used as the accumulator init for the reduction.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  // Converts an index `value` to float `type`, going through an unsigned
  // integer wide enough for the float's bit width (i64 above 32 bits).
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  // Materializes linalg.index `index` as a float of the given type.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  // Builds a list of affine dim expressions from the given dim positions.
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType =
        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation: (N, oy, ox)
    // parallel, (iy, ix) reduction.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2686 
// Lowers tosa.fft2d (complex input) to a linalg.generic accumulating the
// DFT sums; reuses the RFFT2dConverter helpers for shape/index utilities.
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern<FFT2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    // Real and imaginary operands must share the same element type.
    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes
    SmallVector<Value> dynamicSizes;

    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation: (N, oy, ox)
    // parallel, (iy, ix) reduction.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
      // ox) % W ) / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);

      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // The inverse transform negates the angle (sign_val above).
      if (inverse.getValue()) {
        angle = builder.create<arith::MulFOp>(
            loc, angle,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = val_real * cos(a) + val_imag * sin(a);
      // imagComponent = -val_real * sin(a) + val_imag * cos(a);
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);

      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      // (imagComponent already carries the -sin term via the SubFOp above.)
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2823 
2824 } // namespace
2825 
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {

  // We have multiple resize converters to handle degenerate cases. Higher
  // benefit values are tried first, so the most specialized pattern wins.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);

  // Elementwise ops share the pointwise lowering and need the type
  // converter for their operand/result type conversion.
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());

  // Remaining patterns only need the MLIR context.
  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands)
static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result)
static bool operandsAndResultsRanked(Operation *operation)
const float * table
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:368
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:428
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:61
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319