TosaToLinalg.cpp
1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/Index/IR/IndexOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
23 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
24 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/TypeUtilities.h"
30 #include "mlir/Transforms/DialectConversion.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 
34 #include <numeric>
35 
36 using namespace mlir;
37 using namespace mlir::tosa;
38 
39 template <typename T>
40 static arith::ConstantOp
41 createConstFromIntAttribute(Operation *op, const std::string &attrName,
42  Type requiredAttrType, OpBuilder &rewriter) {
43  auto castedN = static_cast<T>(
44  cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
45  return rewriter.create<arith::ConstantOp>(
46  op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
47 }
48 
49 static Value createLinalgBodyCalculationForElementwiseOp(
50  Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51  ConversionPatternRewriter &rewriter) {
52  Location loc = op->getLoc();
53  auto elementTy =
54  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
55 
56  // tosa::AbsOp
57  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
58  return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
59 
60  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
61  auto zero = rewriter.create<arith::ConstantOp>(
62  loc, rewriter.getZeroAttr(elementTy));
63  auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
64  return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
65  }
66 
67  // tosa::AddOp
68  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
69  return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
70 
71  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
72  return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
73 
74  // tosa::SubOp
75  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
76  return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
77 
78  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
79  return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
80 
81  // tosa::IntDivOp
82  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
83  return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
84 
85  // tosa::ReciprocalOp
86  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
87  auto one =
88  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
89  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
90  }
91 
92  // tosa::MulOp
93  if (isa<tosa::MulOp>(op)) {
94  auto shift_val = cast<tosa::MulOp>(op).getShift();
95  ElementsAttr shift_elem;
96  if (!shift_val.getImpl() ||
97  !matchPattern(shift_val, m_Constant(&shift_elem))) {
98  (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
 return nullptr;
99  }
100 
101  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
102 
103  if (isa<FloatType>(elementTy)) {
104  if (shift != 0) {
105  (void)rewriter.notifyMatchFailure(op,
106  "Cannot have shift value for float");
107  return nullptr;
108  }
109  return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
110  }
111 
112  if (isa<IntegerType>(elementTy)) {
113  Value a = args[0];
114  Value b = args[1];
115 
116  if (shift > 0) {
117  auto shiftConst =
118  rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
119  if (!a.getType().isInteger(32))
120  a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
121 
122  if (!b.getType().isInteger(32))
123  b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
124 
125  auto result = rewriter.create<tosa::ApplyScaleOp>(
126  loc, rewriter.getI32Type(), a, b, shiftConst,
127  rewriter.getBoolAttr(false));
128 
129  if (elementTy.isInteger(32))
130  return result;
131 
132  return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
133  }
134 
135  int aWidth = a.getType().getIntOrFloatBitWidth();
136  int bWidth = b.getType().getIntOrFloatBitWidth();
137  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
138 
139  if (aWidth < cWidth)
140  a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
141  if (bWidth < cWidth)
142  b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
143 
144  return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
145  }
146  }
147 
148  // tosa::NegateOp
149  if (isa<tosa::NegateOp>(op)) {
150  if (isa<FloatType>(elementTy))
151  return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
152 
153  if (isa<IntegerType>(elementTy)) {
154  auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
155  auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
156 
157  const int64_t inZp =
158  inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
159  const int64_t outZp =
160  outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
161 
162  if (!inZp && !outZp) {
163  auto constant = rewriter.create<arith::ConstantOp>(
164  loc, IntegerAttr::get(elementTy, 0));
165  return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
166  args[0]);
167  }
168 
169  // Compute the maximum value that can occur in the intermediate buffer.
170  const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
171  const int64_t zpAdd = inZp + outZp;
172  const int64_t maxValue =
173  APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
174  std::abs(zpAdd) + 1;
175 
176  // Convert that maximum value into the maximum bitwidth needed to
177  // represent it. We assume 48-bit numbers may be supported further in
178  // the pipeline.
179  int intermediateBitWidth = 64;
180  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
181  intermediateBitWidth = 16;
182  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
183  intermediateBitWidth = 32;
184  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
185  intermediateBitWidth = 48;
186  }
187 
188  Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
189  Value zpAddValue = rewriter.create<arith::ConstantOp>(
190  loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
191 
192  // The negation can be applied by doing:
193  // outputValue = inZp + outZp - inputValue
194  auto ext =
195  rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
196  auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
197 
198  // Clamp to the negation range.
199  Value min = rewriter.create<arith::ConstantIntOp>(
200  loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
201  intermediateType);
202  Value max = rewriter.create<arith::ConstantIntOp>(
203  loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
204  intermediateType);
205  auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
206 
207  // Truncate to the final value.
208  return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
209  }
210  }
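 // For illustration, a rough worked example of the quantized negation above,
 // assuming an i8 input with inZp = -128 and outZp = 100 (values chosen only
 // for illustration): zpAdd = -28 and maxValue = 127 + 28 + 1 = 156, which
 // does not fit in i8 but fits in i16, so intermediateBitWidth = 16. Per
 // element, the emitted IR then roughly computes:
 //   %ext = arith.extsi %x : i8 to i16
 //   %sub = arith.subi %c-28_i16, %ext : i16   // inZp + outZp - x
 //   clamp %sub to [-128, 127], then arith.trunci back to i8.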
211 
212  // tosa::BitwiseAndOp
213  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
214  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
215 
216  // tosa::BitwiseOrOp
217  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
218  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
219 
220  // tosa::BitwiseNotOp
221  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
222  auto allOnesAttr = rewriter.getIntegerAttr(
223  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
224  auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
225  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
226  }
227 
228  // tosa::BitwiseXOrOp
229  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
230  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
231 
232  // tosa::LogicalLeftShiftOp
233  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
234  return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
235 
236  // tosa::LogicalRightShiftOp
237  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
238  return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
239 
240  // tosa::ArithmeticRightShiftOp
241  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
242  auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
243  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
244  if (!round) {
245  return result;
246  }
247 
248  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
249  auto one =
250  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
251  auto zero =
252  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
253  auto i1one =
254  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
255 
256  // Check that the shift amount (input2) is greater than zero.
257  auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
258  loc, arith::CmpIPredicate::sgt, args[1], zero);
259 
260  // Check whether the last bit shifted out of input1 is 1.
261  auto subtract =
262  rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
263  auto shifted =
264  rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
265  ->getResults();
266  auto truncated =
267  rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
268  auto isInputOdd =
269  rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
270 
271  auto shouldRound = rewriter.create<arith::AndIOp>(
272  loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
273  auto extended =
274  rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
275  return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
276  }
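 // For illustration, a small worked example of the rounding path above
 // (numbers are illustrative): input1 = -7, input2 = 1. The plain arithmetic
 // shift gives -7 >> 1 = -4. The bit shifted out is (-7 >> 0) & 1 = 1 and the
 // shift amount is greater than zero, so shouldRound is true and the final
 // result is -4 + 1 = -3, i.e. -3.5 rounded towards positive infinity.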
277 
278  // tosa::ClzOp
279  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
280  return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
281  }
282 
283  // tosa::LogicalAnd
284  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
285  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
286 
287  // tosa::LogicalNot
288  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
289  auto one = rewriter.create<arith::ConstantOp>(
290  loc, rewriter.getIntegerAttr(elementTy, 1));
291  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
292  }
293 
294  // tosa::LogicalOr
295  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
296  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
297 
298  // tosa::LogicalXor
299  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
300  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
301 
302  // tosa::PowOp
303  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
304  return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
305 
306  // tosa::RsqrtOp
307  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
308  return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
309 
310  // tosa::LogOp
311  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
312  return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
313 
314  // tosa::ExpOp
315  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
316  return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
317 
318  // tosa::SinOp
319  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
320  return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
321 
322  // tosa::CosOp
323  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
324  return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
325 
326  // tosa::TanhOp
327  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
328  return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
329 
330  // tosa::ErfOp
331  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
332  return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
333 
334  // tosa::GreaterOp
335  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
336  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
337  args[0], args[1]);
338 
339  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
340  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
341  args[0], args[1]);
342 
343  // tosa::GreaterEqualOp
344  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
345  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
346  args[0], args[1]);
347 
348  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
349  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
350  args[0], args[1]);
351 
352  // tosa::EqualOp
353  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
354  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
355  args[0], args[1]);
356 
357  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
358  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
359  args[0], args[1]);
360 
361  // tosa::SelectOp
362  if (isa<tosa::SelectOp>(op)) {
363  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
364  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
365  return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
366  }
367 
368  // tosa::MaximumOp
369  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
370  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
371  }
372 
373  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
374  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
375  }
376 
377  // tosa::MinimumOp
378  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
379  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
380  }
381 
382  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
383  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
384  }
385 
386  // tosa::CeilOp
387  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
388  return rewriter.create<math::CeilOp>(loc, resultTypes, args);
389 
390  // tosa::FloorOp
391  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
392  return rewriter.create<math::FloorOp>(loc, resultTypes, args);
393 
394  // tosa::ClampOp
395  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
396  bool losesInfo = false;
397  APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
398  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
399  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
400  APFloat::rmNearestTiesToEven, &losesInfo);
401  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
402  APFloat::rmNearestTiesToEven, &losesInfo);
403  auto min = rewriter.create<arith::ConstantOp>(
404  loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
405  auto max = rewriter.create<arith::ConstantOp>(
406  loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
407  return clampFloatHelper(loc, args[0], min, max, rewriter);
408  }
409 
410  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
411  auto intTy = cast<IntegerType>(elementTy);
412  int64_t min =
413  cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
414  int64_t max =
415  cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
416 
417  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
418  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
419  if (intTy.isUnsignedInteger()) {
420  minRepresentable = 0;
421  if (intTy.getIntOrFloatBitWidth() <= 63) {
422  maxRepresentable =
423  (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
424  .getZExtValue();
425  }
426  } else if (intTy.getIntOrFloatBitWidth() <= 64) {
427  // Ensure that min & max fit into signed n-bit constants.
428  minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
429  .getSExtValue();
430  maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
431  .getSExtValue();
432  }
433  // Ensure that the bounds are representable as n-bit signed/unsigned
434  // integers.
435  min = std::max(min, minRepresentable);
436  max = std::max(max, minRepresentable);
437  min = std::min(min, maxRepresentable);
438  max = std::min(max, maxRepresentable);
439 
440  auto minVal = rewriter.create<arith::ConstantIntOp>(
441  loc, min, intTy.getIntOrFloatBitWidth());
442  auto maxVal = rewriter.create<arith::ConstantIntOp>(
443  loc, max, intTy.getIntOrFloatBitWidth());
444  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
445  intTy.isUnsignedInteger());
446  }
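 // For illustration (values are illustrative): clamping an i8 tensor with
 // attributes min_val = -200 and max_val = 300 first caps the bounds to the
 // representable range, so the emitted arith.constant values are -128 and
 // 127 and the clamp spans the whole i8 range, effectively a no-op.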
447 
448  // tosa::SigmoidOp
449  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
450  auto one =
451  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
452  auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
453  auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
454  auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
455  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
456  }
457 
458  // tosa::CastOp
459  if (isa<tosa::CastOp>(op)) {
460  Type srcTy = elementTy;
461  Type dstTy = resultTypes.front();
462  bool bitExtend =
463  srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
464 
465  if (srcTy == dstTy)
466  return args.front();
467 
468  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
469  return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
470  std::nullopt);
471 
472  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
473  return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
474  std::nullopt);
475 
476  // 1-bit integers need to be treated as signless.
477  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
478  return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
479  std::nullopt);
480 
481  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
482  return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
483  std::nullopt);
484 
485  // Unsigned integers need an unrealized cast so that they can be passed
486  // to UIToFP.
487  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
488  auto unrealizedCast =
489  rewriter
490  .create<UnrealizedConversionCastOp>(
491  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
492  args[0])
493  .getResult(0);
494  return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
495  unrealizedCast);
496  }
497 
498  // All other si-to-fp conversions should be handled by SIToFP.
499  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
500  return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
501  std::nullopt);
502 
503  // When casting to boolean, floats only need to be checked as not-equal to zero.
504  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
505  Value zero = rewriter.create<arith::ConstantOp>(
506  loc, rewriter.getFloatAttr(srcTy, 0.0));
507  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
508  args.front(), zero);
509  }
510 
511  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
512  auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
513 
514  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
515  // Check whether neither int min nor int max can be represented in the
516  // input floating-point type because its exponent range is too short.
517  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
518  APFloat::semanticsMaxExponent(fltSemantics)) {
519  // Use cmp + select to replace infinities by int min / int max. Other
520  // integral values can be represented in the integer space.
521  auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
522  auto posInf = rewriter.create<arith::ConstantOp>(
523  loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
524  APFloat::getInf(fltSemantics)));
525  auto negInf = rewriter.create<arith::ConstantOp>(
526  loc, rewriter.getFloatAttr(
527  getElementTypeOrSelf(srcTy),
528  APFloat::getInf(fltSemantics, /*Negative=*/true)));
529  auto overflow = rewriter.create<arith::CmpFOp>(
530  loc, arith::CmpFPredicate::UEQ, rounded, posInf);
531  auto underflow = rewriter.create<arith::CmpFOp>(
532  loc, arith::CmpFPredicate::UEQ, rounded, negInf);
533  auto intMin = rewriter.create<arith::ConstantOp>(
534  loc, rewriter.getIntegerAttr(
535  getElementTypeOrSelf(dstTy),
536  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
537  auto intMax = rewriter.create<arith::ConstantOp>(
538  loc, rewriter.getIntegerAttr(
539  getElementTypeOrSelf(dstTy),
540  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
541  auto maxClamped =
542  rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
543  return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
544  maxClamped);
545  }
546 
547  auto intMinFP = rewriter.create<arith::ConstantOp>(
548  loc, rewriter.getFloatAttr(
549  getElementTypeOrSelf(srcTy),
550  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
551  .getSExtValue()));
552 
553  // Check whether the mantissa has enough bits to represent int max.
554  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
555  dstTy.getIntOrFloatBitWidth() - 1) {
556  // Int min can also be represented since it is a power of two and thus
557  // consists of a single leading bit. Therefore we can clamp the input
558  // in the floating-point domain.
559 
560  auto intMaxFP = rewriter.create<arith::ConstantOp>(
561  loc, rewriter.getFloatAttr(
562  getElementTypeOrSelf(srcTy),
563  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
564  .getSExtValue()));
565 
566  Value clamped =
567  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
568  return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
569  }
570 
571  // Due to the earlier check we know the exponent range is big enough to represent
572  // int min. We can therefore rely on int max + 1 being representable as
573  // well because it's just int min with a positive sign. So clamp the min
574  // value and compare against that to select the max int value if needed.
575  auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
576  loc, rewriter.getFloatAttr(
577  getElementTypeOrSelf(srcTy),
578  static_cast<double>(
579  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
580  .getSExtValue()) +
581  1.0f));
582 
583  auto intMax = rewriter.create<arith::ConstantOp>(
584  loc, rewriter.getIntegerAttr(
585  getElementTypeOrSelf(dstTy),
586  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
587  auto minClampedFP =
588  rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
589  auto minClamped =
590  rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
591  auto overflow = rewriter.create<arith::CmpFOp>(
592  loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
593  return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
594  minClamped);
595  }
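 // For illustration (types are illustrative): when casting f16 to i32, the
 // maximum f16 exponent (15) is too small to represent i32 min/max, so the
 // first branch above replaces +/-inf by INT_MIN/INT_MAX with cmp + select.
 // When casting f32 to i16, the 24-bit mantissa represents 32767 exactly, so
 // the input is clamped in the float domain before arith.fptosi. When casting
 // f32 to i64, only the minimum is clamped in the float domain, and rounded
 // values >= 2^63 are mapped to INT64_MAX via the final cmp + select.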
596 
597  // When casting to boolean, integers only need to be checked as not-equal
598  // to zero.
599  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
600  Value zero = rewriter.create<arith::ConstantIntOp>(
601  loc, 0, srcTy.getIntOrFloatBitWidth());
602  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
603  args.front(), zero);
604  }
605 
606  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
607  return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
608  std::nullopt);
609 
610  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
611  return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
612  }
613  }
614 
615  (void)rewriter.notifyMatchFailure(
616  op, "unhandled op for linalg body calculation for elementwise op");
617  return nullptr;
618 }
619 
620 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
621  int64_t rank) {
622  // No need to expand if we are already at the desired rank
623  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
624  assert(tensorType && "expected a ranked tensor type");
625  int64_t tensorRank = tensorType.getRank();
626  int64_t numExtraDims = rank - tensorRank;
627  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
628  if (!numExtraDims)
629  return tensor;
630 
631  // Compute reassociation indices
632  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
633  int64_t index = 0;
634  if (tensorRank != 0) {
635  for (index = 0; index <= numExtraDims; index++)
636  reassociationIndices[0].push_back(index);
637  for (size_t position = 1; position < reassociationIndices.size();
638  position++)
639  reassociationIndices[position].push_back(index++);
640  }
641 
642  // Compute result type
643  SmallVector<int64_t> resultShape;
644  for (index = 0; index < numExtraDims; index++)
645  resultShape.push_back(1);
646  for (auto size : tensorType.getShape())
647  resultShape.push_back(size);
648  auto resultType =
649  RankedTensorType::get(resultShape, tensorType.getElementType());
650 
651  // Emit 'tensor.expand_shape' op
652  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
653  reassociationIndices);
654 }
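 // For illustration (shapes are illustrative): expanding a tensor<3x4xf32> to
 // rank 3 gives numExtraDims = 1, reassociation indices [[0, 1], [2]] and a
 // tensor.expand_shape to tensor<1x3x4xf32>; expanding a tensor<5xf32> to
 // rank 3 gives reassociation [[0, 1, 2]] and tensor<1x1x5xf32>.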
655 
656 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
657  Location loc, ValueRange operands,
658  int64_t rank) {
659  return llvm::map_to_vector(operands, [&](Value operand) {
660  return expandRank(rewriter, loc, operand, rank);
661  });
662 }
663 
664 using IndexPool = DenseMap<int64_t, Value>;
665 
666 // Emit an 'arith.constant' op for the given index if it has not been created
667 // yet, or return an existing constant. This prevents excessive creation of
668 // redundant constants and eases readability of the emitted code in unit tests.
669 static Value createIndex(PatternRewriter &rewriter, Location loc,
670  IndexPool &indexPool, int64_t index) {
671  auto [it, inserted] = indexPool.try_emplace(index);
672  if (inserted)
673  it->second =
674  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
675  return it->second;
676 }
677 
678 static Value getTensorDim(PatternRewriter &rewriter, Location loc,
679  IndexPool &indexPool, Value tensor, int64_t index) {
680  auto indexValue = createIndex(rewriter, loc, indexPool, index);
681  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
682 }
683 
684 static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
685  IndexPool &indexPool, Value tensor,
686  int64_t index) {
687  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
688  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
689  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
690  if (shapedType.isDynamicDim(index))
691  return getTensorDim(rewriter, loc, indexPool, tensor, index);
692  return rewriter.getIndexAttr(shapedType.getDimSize(index));
693 }
694 
695 static bool operandsAndResultsRanked(Operation *operation) {
696  auto isRanked = [](Value value) {
697  return isa<RankedTensorType>(value.getType());
698  };
699  return llvm::all_of(operation->getOperands(), isRanked) &&
700  llvm::all_of(operation->getResults(), isRanked);
701 }
702 
703 // Compute the runtime dimension size for dimension 'dim' of the output by
704 // inspecting input 'operands', all of which are expected to have the same rank.
705 // This function returns a pair {targetSize, masterOperand}.
706 //
707 // The runtime size of the output dimension is returned either as a statically
708 // computed attribute or as a runtime SSA value.
709 //
710 // If the target size was inferred directly from one dominating operand, that
711 // operand is returned in 'masterOperand'. If the target size is inferred from
712 // multiple operands, 'masterOperand' is set to nullptr.
713 static std::pair<OpFoldResult, Value>
714 computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
715  ValueRange operands, int64_t dim) {
716  // If any input operand contains a static size greater than 1 for this
717  // dimension, that is the target size. An occurrence of an additional static
718  // dimension greater than 1 with a different value is undefined behavior.
719  for (auto operand : operands) {
720  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
721  if (!ShapedType::isDynamic(size) && size > 1)
722  return {rewriter.getIndexAttr(size), operand};
723  }
724 
725  // Filter operands with dynamic dimension
726  auto operandsWithDynamicDim =
727  llvm::filter_to_vector(operands, [&](Value operand) {
728  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
729  });
730 
731  // If no operand has a dynamic dimension, it means all sizes were 1
732  if (operandsWithDynamicDim.empty())
733  return {rewriter.getIndexAttr(1), operands.front()};
734 
735  // Emit code that computes the runtime size for this dimension. If there is
736  // only one operand with a dynamic dimension, it is considered the master
737  // operand that determines the runtime size of the output dimension.
738  auto targetSize =
739  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
740  if (operandsWithDynamicDim.size() == 1)
741  return {targetSize, operandsWithDynamicDim[0]};
742 
743  // Calculate maximum size among all dynamic dimensions
744  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
745  auto nextSize =
746  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
747  targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
748  }
749  return {targetSize, nullptr};
750 }
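 // For illustration (sizes are illustrative): if three operands have sizes
 // {1, ?, 5} in dimension 'dim', the static size 5 wins and that operand
 // becomes the master operand. If the sizes are {1, ?, ?}, the target size is
 // emitted as an arith.maxui of the two tensor.dim values and masterOperand
 // is nullptr, since no single operand determines the result size.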
751 
752 // Compute the runtime output size for all dimensions. This function returns
753 // a pair {targetShape, masterOperands}.
754 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
755 computeTargetShape(PatternRewriter &rewriter, Location loc,
756  IndexPool &indexPool, ValueRange operands) {
757  assert(!operands.empty());
758  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
759  SmallVector<OpFoldResult> targetShape;
760  SmallVector<Value> masterOperands;
761  for (auto dim : llvm::seq<int64_t>(0, rank)) {
762  auto [targetSize, masterOperand] =
763  computeTargetSize(rewriter, loc, indexPool, operands, dim);
764  targetShape.push_back(targetSize);
765  masterOperands.push_back(masterOperand);
766  }
767  return {targetShape, masterOperands};
768 }
769 
770 static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
771  IndexPool &indexPool, Value operand,
772  int64_t dim, OpFoldResult targetSize,
773  Value masterOperand) {
774  // Nothing to do if this is a static dimension
775  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
776  if (!rankedTensorType.isDynamicDim(dim))
777  return operand;
778 
779  // If the target size for this dimension was directly inferred by only taking
780  // this operand into account, there is no need to broadcast. This is an
781  // optimization that will prevent redundant control flow, and constitutes the
782  // main motivation for tracking "master operands".
783  if (operand == masterOperand)
784  return operand;
785 
786  // Affine maps for 'linalg.generic' op
787  auto rank = rankedTensorType.getRank();
788  SmallVector<AffineExpr> affineExprs;
789  for (auto index : llvm::seq<int64_t>(0, rank)) {
790  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
791  : rewriter.getAffineDimExpr(index);
792  affineExprs.push_back(affineExpr);
793  }
794  auto broadcastAffineMap =
795  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
796  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
797  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
798 
799  // Check if broadcast is necessary
800  auto one = createIndex(rewriter, loc, indexPool, 1);
801  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
802  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
803  loc, arith::CmpIPredicate::eq, runtimeSize, one);
804 
805  // Emit 'then' region of 'scf.if'
806  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
807  // It is not safe to cache constants across regions.
808  // New constants could potentially violate dominance requirements.
809  IndexPool localPool;
810 
811  // Emit 'tensor.empty' op
812  SmallVector<OpFoldResult> outputTensorShape;
813  for (auto index : llvm::seq<int64_t>(0, rank)) {
814  auto size = index == dim ? targetSize
815  : getOrFoldTensorDim(rewriter, loc, localPool,
816  operand, index);
817  outputTensorShape.push_back(size);
818  }
819  Value outputTensor = opBuilder.create<tensor::EmptyOp>(
820  loc, outputTensorShape, rankedTensorType.getElementType());
821 
822  // Emit 'linalg.generic' op
823  auto resultTensor =
824  opBuilder
825  .create<linalg::GenericOp>(
826  loc, outputTensor.getType(), operand, outputTensor, affineMaps,
827  getNParallelLoopsAttrs(rank),
828  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
829  // Emit 'linalg.yield' op
830  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
831  })
832  .getResult(0);
833 
834  // Cast to original operand type if necessary
835  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
836  loc, operand.getType(), resultTensor);
837 
838  // Emit 'scf.yield' op
839  opBuilder.create<scf::YieldOp>(loc, castResultTensor);
840  };
841 
842  // Emit 'else' region of 'scf.if'
843  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
844  opBuilder.create<scf::YieldOp>(loc, operand);
845  };
846 
847  // Emit 'scf.if' op
848  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
849  emitThenRegion, emitElseRegion);
850  return ifOp.getResult(0);
851 }
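 // For illustration, a rough sketch of the IR emitted above for a dynamic
 // dimension that may need broadcasting (shapes and SSA names are
 // illustrative, not verbatim output):
 //   %size = tensor.dim %operand, %c0 : tensor<?x?xf32>
 //   %needsBcast = arith.cmpi eq, %size, %c1 : index
 //   %res = scf.if %needsBcast -> (tensor<?x?xf32>) {
 //     // linalg.generic mapping the broadcast dim of the operand to the
 //     // affine constant 0, writing into a tensor.empty of the target size,
 //     // then tensor.cast back to the original operand type.
 //     scf.yield %cast : tensor<?x?xf32>
 //   } else {
 //     scf.yield %operand : tensor<?x?xf32>
 //   }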
852 
853 static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
854  IndexPool &indexPool, Value operand,
855  ArrayRef<OpFoldResult> targetShape,
856  ArrayRef<Value> masterOperands) {
857  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
858  assert((int64_t)targetShape.size() == rank);
859  assert((int64_t)masterOperands.size() == rank);
860  for (auto index : llvm::seq<int64_t>(0, rank))
861  operand =
862  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
863  targetShape[index], masterOperands[index]);
864  return operand;
865 }
866 
867 static SmallVector<Value>
868 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
869  IndexPool &indexPool, ValueRange operands,
870  ArrayRef<OpFoldResult> targetShape,
871  ArrayRef<Value> masterOperands) {
872  // No need to broadcast for unary operations
873  if (operands.size() == 1)
874  return operands;
875 
876  // Broadcast dynamic dimensions operand by operand
877  return llvm::map_to_vector(operands, [&](Value operand) {
878  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
879  targetShape, masterOperands);
880  });
881 }
882 
883 static LogicalResult
884 emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
885  Operation *operation, ValueRange operands,
886  ArrayRef<OpFoldResult> targetShape,
887  const TypeConverter &converter) {
888  // Generate output tensor
889  auto resultType = cast_or_null<RankedTensorType>(
890  converter.convertType(operation->getResultTypes().front()));
891  if (!resultType) {
892  return rewriter.notifyMatchFailure(operation, "failed to convert type");
893  }
894  Value outputTensor = rewriter.create<tensor::EmptyOp>(
895  loc, targetShape, resultType.getElementType());
896 
897  // Create affine maps. Input affine maps broadcast static dimensions of size
898  // 1. The output affine map is an identity map.
899  //
900  auto rank = resultType.getRank();
901  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
902  auto shape = cast<ShapedType>(operand.getType()).getShape();
903  SmallVector<AffineExpr> affineExprs;
904  for (auto it : llvm::enumerate(shape)) {
905  // Prefer producing identity maps whenever possible (i.e. no broadcasting
906  // needed) because some transforms (like reshape folding)
907  // do not support affine constant exprs.
908  bool requiresBroadcast =
909  (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
910  auto affineExpr = requiresBroadcast
911  ? rewriter.getAffineConstantExpr(0)
912  : rewriter.getAffineDimExpr(it.index());
913  affineExprs.push_back(affineExpr);
914  }
915  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
916  });
917  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
918 
919  // Emit 'linalg.generic' op
920  bool encounteredError = false;
921  auto linalgOp = rewriter.create<linalg::GenericOp>(
922  loc, outputTensor.getType(), operands, outputTensor, affineMaps,
923  getNParallelLoopsAttrs(rank),
924  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
925  Value opResult = createLinalgBodyCalculationForElementwiseOp(
926  operation, blockArgs.take_front(operation->getNumOperands()),
927  {resultType.getElementType()}, rewriter);
928  if (!opResult) {
929  encounteredError = true;
930  return;
931  }
932  opBuilder.create<linalg::YieldOp>(loc, opResult);
933  });
934  if (encounteredError)
935  return rewriter.notifyMatchFailure(
936  operation, "unable to create linalg.generic body for elementwise op");
937 
938  // Cast 'linalg.generic' result into original result type if needed
939  auto castResult = rewriter.createOrFold<tensor::CastOp>(
940  loc, resultType, linalgOp->getResult(0));
941  rewriter.replaceOp(operation, castResult);
942  return success();
943 }
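 // For illustration, a rough sketch of the lowering produced for a rank-2
 // tosa.add whose second operand broadcasts along the last dimension (shapes
 // and SSA names are illustrative):
 //   %empty = tensor.empty(...) : tensor<?x?xf32>
 //   %0 = linalg.generic {
 //          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
 //                           affine_map<(d0, d1) -> (d0, 0)>,
 //                           affine_map<(d0, d1) -> (d0, d1)>],
 //          iterator_types = ["parallel", "parallel"]}
 //        ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x1xf32>)
 //        outs(%empty : tensor<?x?xf32>) {
 //        ^bb0(%a: f32, %b: f32, %out: f32):
 //          %sum = arith.addf %a, %b : f32
 //          linalg.yield %sum : f32
 //        } -> tensor<?x?xf32>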
944 
945 static LogicalResult
946 elementwiseMatchAndRewrite(Operation *operation, ValueRange operands,
947  ConversionPatternRewriter &rewriter,
948  const TypeConverter &converter) {
949 
950  // Collect op properties
951  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
952  assert(operation->getNumOperands() >= 1 &&
953  "elementwise op expects at least 1 operand");
954  if (!operandsAndResultsRanked(operation))
955  return rewriter.notifyMatchFailure(operation,
956  "Unranked tensors not supported");
957 
958  // Lower operation
959  IndexPool indexPool;
960  auto loc = operation->getLoc();
961  auto rank =
962  cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
963  // For the mul op we need to avoid expanding the rank of the optional shift
964  // input.
965  auto operandsToExpand =
966  isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
967 
968  auto expandedOperands =
969  expandInputRanks(rewriter, loc, operandsToExpand, rank);
970  auto [targetShape, masterOperands] =
971  computeTargetShape(rewriter, loc, indexPool, expandedOperands);
972  auto broadcastOperands = broadcastDynamicDimensions(
973  rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
974  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
975  targetShape, converter);
976 }
977 
978 // Returns the constant initial value for a given reduction operation. The
979 // attribute type varies depending on the element type required.
980 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
981  PatternRewriter &rewriter) {
982  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
983  return rewriter.getFloatAttr(elementTy, 0.0);
984 
985  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
986  return rewriter.getIntegerAttr(elementTy, 0);
987 
988  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
989  return rewriter.getFloatAttr(elementTy, 1.0);
990 
991  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
992  return rewriter.getIntegerAttr(elementTy, 1);
993 
994  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
995  return rewriter.getFloatAttr(
996  elementTy, APFloat::getLargest(
997  cast<FloatType>(elementTy).getFloatSemantics(), false));
998 
999  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1000  return rewriter.getIntegerAttr(
1001  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1002 
1003  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1004  return rewriter.getFloatAttr(
1005  elementTy, APFloat::getLargest(
1006  cast<FloatType>(elementTy).getFloatSemantics(), true));
1007 
1008  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1009  return rewriter.getIntegerAttr(
1010  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1011 
1012  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1013  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1014 
1015  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1016  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1017 
1018  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1019  return rewriter.getFloatAttr(
1020  elementTy, APFloat::getLargest(
1021  cast<FloatType>(elementTy).getFloatSemantics(), true));
1022 
1023  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1024  return rewriter.getIntegerAttr(
1025  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1026 
1027  return {};
1028 }
1029 
1030 // Creates the body calculation for a reduction. The operations vary depending
1031 // on the input type.
1032 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1033  ValueRange args,
1034  Type elementTy,
1035  PatternRewriter &rewriter) {
1036  Location loc = op->getLoc();
1037  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1038  return rewriter.create<arith::AddFOp>(loc, args);
1039  }
1040 
1041  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1042  return rewriter.create<arith::AddIOp>(loc, args);
1043  }
1044 
1045  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
1046  return rewriter.create<arith::MulFOp>(loc, args);
1047  }
1048 
1049  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
1050  return rewriter.create<arith::MulIOp>(loc, args);
1051  }
1052 
1053  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1054  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1055  }
1056 
1057  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1058  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1059  }
1060 
1061  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1062  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1063  }
1064 
1065  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1066  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1067  }
1068 
1069  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1070  return rewriter.create<arith::AndIOp>(loc, args);
1071 
1072  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1073  return rewriter.create<arith::OrIOp>(loc, args);
1074 
1075  return {};
1076 }
1077 
1078 // Performs the match and rewrite for reduction operations. This includes
1079 // declaring a correctly sized initial value, and the linalg.generic operation
1080 // that reduces across the specified axis.
1081 static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1082  PatternRewriter &rewriter) {
1083  auto loc = op->getLoc();
1084  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
1085  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
1086  auto elementTy = resultTy.getElementType();
1087  Value input = op->getOperand(0);
1088 
1089  SmallVector<int64_t> reduceShape;
1090  SmallVector<Value> dynDims;
1091  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1092  if (axis != i) {
1093  reduceShape.push_back(inputTy.getDimSize(i));
1094  if (inputTy.isDynamicDim(i))
1095  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1096  }
1097  }
1098 
1099  // First fill the output buffer with the init value.
1100  auto emptyTensor =
1101  rewriter
1102  .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1103  dynDims)
1104  .getResult();
1105 
1106  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1107  if (!fillValueAttr)
1108  return rewriter.notifyMatchFailure(
1109  op, "No initial value found for reduction operation");
1110 
1111  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1112  auto filledTensor = rewriter
1113  .create<linalg::FillOp>(loc, ValueRange{fillValue},
1114  ValueRange{emptyTensor})
1115  .result();
1116 
1117  bool didEncounterError = false;
1118  auto linalgOp = rewriter.create<linalg::ReduceOp>(
1119  loc, input, filledTensor, axis,
1120  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1121  auto result = createLinalgBodyCalculationForReduceOp(
1122  op, blockArgs, elementTy, rewriter);
1123  if (result)
1124  didEncounterError = true;
1125 
1126  nestedBuilder.create<linalg::YieldOp>(loc, result);
1127  });
1128 
1129  if (!didEncounterError)
1130  return rewriter.notifyMatchFailure(
1131  op, "unable to create linalg.generic body for reduce op");
1132 
1133  SmallVector<ReassociationExprs, 4> reassociationMap;
1134  uint64_t expandInputRank =
1135  cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
1136  reassociationMap.resize(expandInputRank);
1137 
1138  for (uint64_t i = 0; i < expandInputRank; i++) {
1139  int32_t dimToPush = i > axis ? i + 1 : i;
1140  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1141  }
1142 
1143  if (expandInputRank != 0) {
1144  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1145  reassociationMap[expandedDim].push_back(
1146  rewriter.getAffineDimExpr(expandedDim + 1));
1147  }
1148 
1149  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1150  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1151  // not have access to such information. This matters when handling dynamically
1152  // sized tensors.
1153  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1154  op, resultTy, linalgOp.getResults()[0], reassociationMap);
1155  return success();
1156 }
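 // For illustration, a rough sketch of the lowering produced for
 // tosa.reduce_sum on axis 1 of a tensor<4x8xf32> (names are illustrative):
 //   %empty = tensor.empty() : tensor<4xf32>
 //   %fill = linalg.fill ins(%cst_0 : f32) outs(%empty : tensor<4xf32>)
 //   %red = linalg.reduce ins(%in : tensor<4x8xf32>)
 //                        outs(%fill : tensor<4xf32>) dimensions = [1]
 //     (%a: f32, %acc: f32) {
 //       %s = arith.addf %a, %acc : f32
 //       linalg.yield %s : f32
 //     }
 //   %res = tensor.expand_shape %red [[0, 1]]
 //            : tensor<4xf32> into tensor<4x1xf32>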
1157 
1158 namespace {
1159 
1160 template <typename SrcOp>
1161 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1162 public:
1163  using OpConversionPattern<SrcOp>::OpConversionPattern;
1164  using typename OpConversionPattern<SrcOp>::OpAdaptor;
1165 
1166  LogicalResult
1167  matchAndRewrite(SrcOp op, OpAdaptor operands,
1168  ConversionPatternRewriter &rewriter) const final {
1169  return elementwiseMatchAndRewrite(
1170  op, operands.getOperands(), rewriter, *this->getTypeConverter());
1171  }
1172 };
1173 
1174 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1175 public:
1176  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1177 
1178  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1179  PatternRewriter &rewriter) const final {
1180  auto loc = op.getLoc();
1181  auto input = op.getInput();
1182  auto inputTy = cast<ShapedType>(op.getInput().getType());
1183  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1184  unsigned rank = inputTy.getRank();
1185 
1186  // This is an illegal configuration: terminate and log an error.
1187  if (op.getDoubleRound() && !op.getScale32())
1188  return rewriter.notifyMatchFailure(
1189  op, "tosa.rescale requires scale32 for double_round to be true");
1190 
1191  if (!isa<IntegerType>(inputTy.getElementType()))
1192  return rewriter.notifyMatchFailure(op, "only support integer type");
1193 
1194  SmallVector<Value> dynDims;
1195  for (int i = 0; i < outputTy.getRank(); i++) {
1196  if (outputTy.isDynamicDim(i)) {
1197  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1198  }
1199  }
1200 
1201  // The shift and multiplier values.
1202  SmallVector<int32_t> multiplierValues(op.getMultiplier());
1203  SmallVector<int8_t> shiftValues(op.getShift());
1204 
1205  // Shifting by more than the bitwidth would produce 0, so zero out the multiplier and shift.
1206  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1207  if (shiftValues[i] > 63) {
1208  shiftValues[i] = 0;
1209  multiplierValues[i] = 0;
1210  }
1211  }
1212 
1213  // Double rounding only occurs if the shift is greater than 31; check
1214  // whether this is ever the case.
1215  bool doubleRound =
1216  op.getDoubleRound() &&
1217  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1218 
1219  SmallVector<AffineMap> indexingMaps = {
1220  rewriter.getMultiDimIdentityMap(rank)};
1221  SmallVector<Value, 4> genericInputs = {input};
1222 
1223  // If we are rescaling per-channel then we need to store the multiplier
1224  // values in a buffer.
1225  Value multiplierConstant;
1226  int64_t multiplierArg = 0;
1227  if (multiplierValues.size() == 1) {
1228  multiplierConstant = rewriter.create<arith::ConstantOp>(
1229  loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1230  } else {
1231  SmallVector<AffineExpr, 2> multiplierExprs{
1232  rewriter.getAffineDimExpr(rank - 1)};
1233  auto multiplierType =
1234  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1235  rewriter.getI32Type());
1236  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1237  loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1238 
1239  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1240  /*symbolCount=*/0, multiplierExprs,
1241  rewriter.getContext()));
1242 
1243  multiplierArg = indexingMaps.size() - 1;
1244  }
1245 
1246  // If we are rescaling per-channel then we need to store the shift
1247  // values in a buffer.
1248  Value shiftConstant;
1249  int64_t shiftArg = 0;
1250  if (shiftValues.size() == 1) {
1251  shiftConstant = rewriter.create<arith::ConstantOp>(
1252  loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1253  } else {
1254  SmallVector<AffineExpr, 2> shiftExprs = {
1255  rewriter.getAffineDimExpr(rank - 1)};
1256  auto shiftType =
1257  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1258  rewriter.getIntegerType(8));
1259  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1260  loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1261  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1262  /*symbolCount=*/0, shiftExprs,
1263  rewriter.getContext()));
1264  shiftArg = indexingMaps.size() - 1;
1265  }
1266 
1267  // Indexing maps for output values.
1268  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1269 
1270  // Construct the indexing maps needed for linalg.generic ops.
1271  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1272  loc, outputTy.getShape(), outputTy.getElementType(),
1273  ArrayRef<Value>({dynDims}));
1274 
1275  auto linalgOp = rewriter.create<linalg::GenericOp>(
1276  loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1277  getNParallelLoopsAttrs(rank),
1278  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1279  ValueRange blockArgs) {
1280  Value value = blockArgs[0];
1281  Type valueTy = value.getType();
1282 
1283  // For now we do all of our math at a fixed wider bitwidth (32 or 48
1284  // bits). This is not optimal but should be correct for now; consider
1285  // computing the exact required bit depth later.
1286  int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1287 
1288  auto inputZp = createConstFromIntAttribute<int32_t>(
1289  op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1290  nestedBuilder);
1291  auto outputZp = createConstFromIntAttribute<int32_t>(
1292  op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1293 
1294  Value multiplier = multiplierConstant ? multiplierConstant
1295  : blockArgs[multiplierArg];
1296  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1297 
1298  if (valueTy.getIntOrFloatBitWidth() < 32) {
1299  if (op.getInputUnsigned()) {
1300  value = nestedBuilder.create<arith::ExtUIOp>(
1301  nestedLoc, nestedBuilder.getI32Type(), value);
1302  } else {
1303  value = nestedBuilder.create<arith::ExtSIOp>(
1304  nestedLoc, nestedBuilder.getI32Type(), value);
1305  }
1306  }
1307 
1308  value =
1309  nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1310 
1311  value = nestedBuilder.create<tosa::ApplyScaleOp>(
1312  loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1313  nestedBuilder.getBoolAttr(doubleRound));
1314 
1315  // Move to the new zero-point.
1316  value =
1317  nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1318 
1319  // Saturate to the output size.
1320  IntegerType outIntType =
1321  cast<IntegerType>(blockArgs.back().getType());
1322  unsigned outBitWidth = outIntType.getWidth();
1323 
1324  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1325  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1326 
1327  // Unsigned integers have a different output range.
1328  if (op.getOutputUnsigned()) {
1329  intMin = 0;
1330  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1331  }
1332 
1333  auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1334  loc, nestedBuilder.getI32IntegerAttr(intMin));
1335  auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1336  loc, nestedBuilder.getI32IntegerAttr(intMax));
1337 
1338  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1339  nestedBuilder, /*isUnsigned=*/false);
1340 
1341  if (outIntType.getWidth() < 32) {
1342  value = nestedBuilder.create<arith::TruncIOp>(
1343  nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1344  value);
1345  }
1346 
1347  nestedBuilder.create<linalg::YieldOp>(loc, value);
1348  });
1349 
1350  rewriter.replaceOp(op, linalgOp->getResults());
1351  return success();
1352  }
1353 };
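// For illustration, the per-element computation emitted above is roughly
// (pseudo-code, signedness handling elided):
//   v   = extend(value) to i32 (or i48 for wide inputs)
//   v   = v - input_zp
//   v   = tosa.apply_scale(v, multiplier, shift, double_round)
//   v   = v + output_zp
//   out = trunc(clamp(v, output_min, output_max))
// where output_min / output_max are the signed (or unsigned) bounds of the
// output integer type.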
1354 
1355 // Handle the resize case where the input is a 1x1 image. This case can
1356 // entirely avoid the extract operations, which are much more difficult to
1357 // optimize away.
1358 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1359 public:
1360  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1361 
1362  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1363  PatternRewriter &rewriter) const final {
1364  Location loc = op.getLoc();
1365  ImplicitLocOpBuilder builder(loc, rewriter);
1366  auto input = op.getInput();
1367  auto inputTy = cast<RankedTensorType>(input.getType());
1368  auto resultTy = cast<RankedTensorType>(op.getType());
1369  const bool isBilinear = op.getMode() == "BILINEAR";
1370 
1371  auto inputH = inputTy.getDimSize(1);
1372  auto inputW = inputTy.getDimSize(2);
1373  auto outputH = resultTy.getDimSize(1);
1374  auto outputW = resultTy.getDimSize(2);
1375 
1376  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1377  return rewriter.notifyMatchFailure(
1378  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1379 
1380  // TODO(suderman): These string values should be declared in the TOSA dialect.
1381  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1382  return rewriter.notifyMatchFailure(
1383  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1384 
1385  if (inputTy == resultTy) {
1386  rewriter.replaceOp(op, input);
1387  return success();
1388  }
1389 
1390  ArrayRef<int64_t> scale = op.getScale();
1391 
1392  // Collapse the unit width and height away.
1393  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1394  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1395  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1396  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1397  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1398 
1399  auto collapseTy =
1400  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1401  inputTy.getElementType());
1402  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1403  reassociationMap);
1404 
1405  // Get any dynamic shapes that appear in the input format.
1406  llvm::SmallVector<Value> outputDynSize;
1407  if (inputTy.isDynamicDim(0))
1408  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1409  if (inputTy.isDynamicDim(3))
1410  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1411 
1412  // Generate the elementwise operation for casting and scaling the input value.
1413  auto genericTy = collapseTy.clone(resultTy.getElementType());
1414  Value empty = builder.create<tensor::EmptyOp>(
1415  genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1416  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1417  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1418  utils::IteratorType::parallel);
1419 
1420  auto generic = builder.create<linalg::GenericOp>(
1421  genericTy, ValueRange{collapse}, ValueRange{empty},
1422  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1423  [=](OpBuilder &b, Location loc, ValueRange args) {
1424  Value value = args[0];
1425  // This is the quantized case.
1426  if (inputTy.getElementType() != resultTy.getElementType()) {
1427  value =
1428  b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1429 
1430  if (isBilinear && scale[0] != 0) {
1431  Value scaleY = b.create<arith::ConstantOp>(
1432  loc, b.getI32IntegerAttr(scale[0]));
1433  value = b.create<arith::MulIOp>(loc, value, scaleY);
1434  }
1435 
1436  if (isBilinear && scale[2] != 0) {
1437  Value scaleX = b.create<arith::ConstantOp>(
1438  loc, b.getI32IntegerAttr(scale[2]));
1439  value = b.create<arith::MulIOp>(loc, value, scaleX);
1440  }
1441  }
1442 
1443  b.create<linalg::YieldOp>(loc, value);
1444  });
1445 
1446  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1447  op, resultTy, generic.getResults()[0], reassociationMap);
1448  return success();
1449  }
1450 };
1451 
1452 // TOSA resize with width or height of 1 may be broadcasted to a wider
1453 // dimension. This is done by materializing a new tosa.resize without
1454 // the broadcasting behavior, and an explicit broadcast afterwards.
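// For example (illustrative shapes), resizing tensor<1x1x1x8xf32> to
// tensor<1x5x7x8xf32> becomes a 1x1 -> 1x1 tosa.resize whose result is then
// broadcast along the unit H and W dimensions to 5 and 7 by a linalg.generic.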
1455 class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1456 public:
1457  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1458 
1459  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1460  PatternRewriter &rewriter) const final {
1461  Location loc = op.getLoc();
1462  ImplicitLocOpBuilder builder(loc, rewriter);
1463  auto input = op.getInput();
1464  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1465  auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1466 
1467  if (!inputTy || !resultTy)
1468  return rewriter.notifyMatchFailure(op,
1469  "requires ranked input/output types");
1470 
1471  auto batch = inputTy.getDimSize(0);
1472  auto channels = inputTy.getDimSize(3);
1473  auto inputH = inputTy.getDimSize(1);
1474  auto inputW = inputTy.getDimSize(2);
1475  auto outputH = resultTy.getDimSize(1);
1476  auto outputW = resultTy.getDimSize(2);
1477 
1478  if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1479  return rewriter.notifyMatchFailure(
1480  op, "tosa.resize has no broadcasting behavior");
1481 
1482  // For any dimension that is broadcastable we generate a size of 1
1483  // on the output.
1484  llvm::SmallVector<int64_t> resizeShape;
1485  resizeShape.push_back(batch);
1486  resizeShape.push_back(inputH == 1 ? 1 : outputH);
1487  resizeShape.push_back(inputW == 1 ? 1 : outputW);
1488  resizeShape.push_back(channels);
1489 
1490  auto resizeTy = resultTy.clone(resizeShape);
1491  auto resize =
1492  builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1493 
1494  // Collapse any unit result dims.
1495  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1496  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1497  reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1498  if (inputH != 1)
1499  reassociationMap.push_back({});
1500  reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1501  if (inputW != 1)
1502  reassociationMap.push_back({});
1503  reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1504 
1505  llvm::SmallVector<int64_t> collapseShape = {batch};
1506  if (inputH != 1)
1507  collapseShape.push_back(outputH);
1508  if (inputW != 1)
1509  collapseShape.push_back(outputW);
1510  collapseShape.push_back(channels);
1511 
1512  auto collapseTy = resultTy.clone(collapseShape);
1513  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1514  reassociationMap);
1515 
1516  // Broadcast the collapsed shape to the output result.
1517  llvm::SmallVector<Value> outputDynSize;
1518  if (inputTy.isDynamicDim(0))
1519  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1520  if (inputTy.isDynamicDim(3))
1521  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1522 
1523  SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1524  utils::IteratorType::parallel);
1525  Value empty = builder.create<tensor::EmptyOp>(
1526  resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1527 
1528  SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1529  if (inputH != 1)
1530  inputExprs.push_back(rewriter.getAffineDimExpr(1));
1531  if (inputW != 1)
1532  inputExprs.push_back(rewriter.getAffineDimExpr(2));
1533  inputExprs.push_back(rewriter.getAffineDimExpr(3));
1534 
1535  auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1536  inputExprs, rewriter.getContext());
1537 
1538  auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1539  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1540  op, resultTy, ValueRange{collapse}, ValueRange{empty},
1541  ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1542  [=](OpBuilder &b, Location loc, ValueRange args) {
1543  Value value = args[0];
1544  b.create<linalg::YieldOp>(loc, value);
1545  });
1546 
1547  return success();
1548  }
1549 };
1550 
1551 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1552 public:
1553  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1554 
1555  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1556  PatternRewriter &rewriter) const final {
1557  Location loc = op.getLoc();
1558  ImplicitLocOpBuilder b(loc, rewriter);
1559  auto input = op.getInput();
1560  auto inputTy = cast<ShapedType>(input.getType());
1561  auto resultTy = cast<ShapedType>(op.getType());
1562  auto resultETy = resultTy.getElementType();
1563 
1564  bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1565  auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1566 
1567  auto imageH = inputTy.getShape()[1];
1568  auto imageW = inputTy.getShape()[2];
1569 
1570  auto dynamicDimsOr =
1571  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1572  if (!dynamicDimsOr.has_value())
1573  return rewriter.notifyMatchFailure(
1574  op, "unable to get dynamic dimensions of tosa.resize");
1575 
1576  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1577  return rewriter.notifyMatchFailure(
1578  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1579 
1580  SmallVector<AffineMap, 2> affineMaps = {
1581  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1582  auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1583  *dynamicDimsOr);
1584  auto genericOp = b.create<linalg::GenericOp>(
1585  resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1586  getNParallelLoopsAttrs(resultTy.getRank()));
1587  Value resize = genericOp.getResult(0);
1588 
1589  {
1590  OpBuilder::InsertionGuard regionGuard(b);
1591  b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1592  TypeRange({resultETy}), loc);
1593  Value batch = b.create<linalg::IndexOp>(0);
1594  Value y = b.create<linalg::IndexOp>(1);
1595  Value x = b.create<linalg::IndexOp>(2);
1596  Value channel = b.create<linalg::IndexOp>(3);
1597 
1598  Value zeroI32 =
1599  b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1600  Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1601  Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1602  Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1603 
1604  Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1605  Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1606 
1607  ArrayRef<int64_t> offset = op.getOffset();
1608  ArrayRef<int64_t> border = op.getBorder();
1609  ArrayRef<int64_t> scale = op.getScale();
1610 
1611  Value yScaleN, yScaleD, xScaleN, xScaleD;
1612  yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1613  yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1614  xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1615  xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1616 
1617  Value yOffset, xOffset, yBorder, xBorder;
1618  yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1619  xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1620  yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1621  xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1622 
1623  // Compute the ix and dx values for both the X and Y dimensions.
1624  auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1625  Value scaleN, Value scaleD, Value offset,
1626  int size, ImplicitLocOpBuilder &b) {
1627  if (size == 1) {
1628  index = zeroI32;
1629  delta = zeroFp;
1630  return;
1631  }
1632  // x = x * scale_d + offset;
1633  // ix = floor(x / scale_n)
1634  Value val = b.create<arith::MulIOp>(in, scaleD);
1635  val = b.create<arith::AddIOp>(val, offset);
1636  index = b.create<arith::FloorDivSIOp>(val, scaleN);
1637 
1638  // rx = x % scale_n
1639  // dx = rx / scale_n
1640  Value r = b.create<arith::RemSIOp>(val, scaleN);
1641  Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1642  Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1643  delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1644  };
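  // Worked example (illustrative values): with scale_n = 4, scale_d = 2,
  // offset = 1 and output coordinate x = 3, val = 3 * 2 + 1 = 7, so
  // ix = floor(7 / 4) = 1 and dx = (7 mod 4) / 4 = 0.75.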
1645 
1646  // Compute the ix and dx values for the X and Y dimensions - int case.
1647  auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1648  Value scaleN, Value scaleD, Value offset,
1649  int size, ImplicitLocOpBuilder &b) {
1650  if (size == 1) {
1651  index = zeroI32;
1652  delta = zeroI32;
1653  return;
1654  }
1655  // x = x * scale_d + offset;
1656  // ix = floor(x / scale_n)
1657  // dx = x - ix * scale_n;
1658  Value val = b.create<arith::MulIOp>(in, scaleD);
1659  val = b.create<arith::AddIOp>(val, offset);
1660  index = b.create<arith::DivSIOp>(val, scaleN);
1661  delta = b.create<arith::MulIOp>(index, scaleN);
1662  delta = b.create<arith::SubIOp>(val, delta);
1663  };
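  // With the same illustrative values (val = 7, scale_n = 4) the integer path
  // gives ix = 7 / 4 = 1 and dx = 7 - 1 * 4 = 3, i.e. the remainder kept as an
  // integer.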
1664 
1665  Value ix, iy, dx, dy;
1666  if (floatingPointMode) {
1667  getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1668  getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1669  } else {
1670  getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1671  getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1672  }
1673 
1674  if (op.getMode() == "NEAREST_NEIGHBOR") {
1675  auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1676 
1677  auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1678  Value max, int size,
1679  ImplicitLocOpBuilder &b) -> Value {
1680  if (size == 1) {
1681  return b.create<arith::ConstantIndexOp>(0);
1682  }
1683 
1684  Value pred;
1685  if (floatingPointMode) {
1686  auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1687  pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1688  } else {
1689  Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1690  pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1691  dvalDouble, scale);
1692  }
1693 
1694  auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1695  val = b.create<arith::AddIOp>(val, offset);
1696  val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
1697  return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1698  };
1699 
1700  iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1701  ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1702 
1703  Value result = b.create<tensor::ExtractOp>(
1704  input, ValueRange{batch, iy, ix, channel});
1705 
1706  b.create<linalg::YieldOp>(result);
1707  } else {
1708  // The mode here must be BILINEAR.
1709  assert(op.getMode() == "BILINEAR");
1710 
1711  auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1712 
1713  auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1714  Value max, ImplicitLocOpBuilder &b) {
1715  val0 = in;
1716  val1 = b.create<arith::AddIOp>(val0, oneVal);
1717  val0 =
1718  clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1719  val1 =
1720  clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
1721  val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1722  val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1723  };
1724 
1725  // Linalg equivalent to the section below:
1726  // int16_t iy0 = apply_max(iy, 0);
1727  // int16_t iy1 = apply_min(iy + 1, IH - 1);
1728  // int16_t ix0 = apply_max(ix, 0);
1729  // int16_t ix1 = apply_min(ix + 1, IW - 1);
1730  Value x0, x1, y0, y1;
1731  getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1732  getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1733 
1734  Value y0x0 = b.create<tensor::ExtractOp>(
1735  input, ValueRange{batch, y0, x0, channel});
1736  Value y0x1 = b.create<tensor::ExtractOp>(
1737  input, ValueRange{batch, y0, x1, channel});
1738  Value y1x0 = b.create<tensor::ExtractOp>(
1739  input, ValueRange{batch, y1, x0, channel});
1740  Value y1x1 = b.create<tensor::ExtractOp>(
1741  input, ValueRange{batch, y1, x1, channel});
1742 
1743  if (floatingPointMode) {
1744  auto oneVal =
1745  b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1746  auto interpolate = [&](Value val0, Value val1, Value delta,
1747  int inputSize,
1748  ImplicitLocOpBuilder &b) -> Value {
1749  if (inputSize == 1)
1750  return val0;
1751  Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1752  Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1753  Value mul1 = b.create<arith::MulFOp>(val1, delta);
1754  return b.create<arith::AddFOp>(mul0, mul1);
1755  };
1756 
1757  // Linalg equivalent to the section below:
1758  // topAcc = v00 * (unit_x - dx);
1759  // topAcc += v01 * dx;
1760  Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1761 
1762  // Linalg equivalent to the section below:
1763  // bottomAcc = v10 * (unit_x - dx);
1764  // bottomAcc += v11 * dx;
1765  Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1766 
1767  // Linalg equivalent to the section below:
1768  // result = topAcc * (unit_y - dy) + bottomAcc * dy
1769  Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1770  b.create<linalg::YieldOp>(result);
1771  } else {
1772  // Perform in quantized space.
1773  y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1774  y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1775  y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1776  y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1777 
1778  const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1779  if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1780  dx = b.create<arith::ExtSIOp>(resultETy, dx);
1781  dy = b.create<arith::ExtSIOp>(resultETy, dy);
1782  }
1783 
1784  Value yScaleNExt = yScaleN;
1785  Value xScaleNExt = xScaleN;
1786 
1787  const int64_t scaleBitwidth =
1788  xScaleN.getType().getIntOrFloatBitWidth();
1789  if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1790  yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1791  xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1792  }
1793 
1794  auto interpolate = [](Value val0, Value val1, Value weight1,
1795  Value scale, int inputSize,
1796  ImplicitLocOpBuilder &b) -> Value {
1797  if (inputSize == 1)
1798  return b.create<arith::MulIOp>(val0, scale);
1799  Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1800  Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1801  Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1802  return b.create<arith::AddIOp>(mul0, mul1);
1803  };
1804 
1805  Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1806  Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1807  Value result =
1808  interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1809  b.create<linalg::YieldOp>(result);
1810  }
1811  }
1812  }
1813 
1814  rewriter.replaceOp(op, resize);
1815  return success();
1816  }
1817 };
1818 
1819 // At the codegen level any identity operations should be removed. Any cases
1820 // where identity is load-bearing (e.g. cross device computation) should be
1821 // handled before lowering to codegen.
1822 template <typename SrcOp>
1823 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1824 public:
1825  using OpRewritePattern<SrcOp>::OpRewritePattern;
1826 
1827  LogicalResult matchAndRewrite(SrcOp op,
1828  PatternRewriter &rewriter) const final {
1829  rewriter.replaceOp(op, op.getOperation()->getOperands());
1830  return success();
1831  }
1832 };
1833 
1834 template <typename SrcOp>
1835 class ReduceConverter : public OpRewritePattern<SrcOp> {
1836 public:
1837  using OpRewritePattern<SrcOp>::OpRewritePattern;
1838 
1839  LogicalResult matchAndRewrite(SrcOp reduceOp,
1840  PatternRewriter &rewriter) const final {
1841  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1842  }
1843 };
1844 
1845 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1846 public:
1847  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1848 
1849  LogicalResult matchAndRewrite(tosa::ReverseOp op,
1850  PatternRewriter &rewriter) const final {
1851  auto loc = op.getLoc();
1852  Value input = op.getInput1();
1853  auto inputTy = cast<ShapedType>(input.getType());
1854  auto resultTy = cast<ShapedType>(op.getType());
1855  auto axis = op.getAxis();
1856 
1857  SmallVector<Value> dynDims;
1858  for (int i = 0; i < inputTy.getRank(); i++) {
1859  if (inputTy.isDynamicDim(i)) {
1860  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1861  }
1862  }
1863 
1864  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1865 
1866  // First fill the output buffer with the init value.
1867  auto emptyTensor = rewriter
1868  .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1869  inputTy.getElementType(),
1870  ArrayRef<Value>({dynDims}))
1871  .getResult();
1872  SmallVector<AffineMap, 2> affineMaps = {
1873  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1874 
1875  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1876  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
1877  getNParallelLoopsAttrs(resultTy.getRank()),
1878  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1879  llvm::SmallVector<Value> indices;
1880  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1881  Value index =
1882  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1883  if (i == axis) {
1884  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1885  auto sizeMinusOne =
1886  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1887  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1888  index);
1889  }
1890 
1891  indices.push_back(index);
1892  }
1893 
1894  auto extract = nestedBuilder.create<tensor::ExtractOp>(
1895  nestedLoc, input, indices);
1896  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1897  extract.getResult());
1898  });
1899  return success();
1900  }
1901 };
1902 
1903 // This converter translates a tile operation to a reshape, broadcast, reshape.
1904 // The first reshape minimally expands each tiled dimension to include a
1905 // preceding size-1 dim. This dim is then broadcast to the appropriate
1906 // multiple.
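// For example (illustrative shapes), tiling tensor<2x3xf32> with multiples
// [2, 1] expands to the generic shape 2x2x1x3, broadcasts the input across the
// newly added dims, and reshapes the result to tensor<4x3xf32>.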
1907 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1908  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1909 
1910  LogicalResult
1911  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1912  ConversionPatternRewriter &rewriter) const override {
1913  auto loc = op.getLoc();
1914  auto input = op.getInput1();
1915  auto inputTy = cast<ShapedType>(input.getType());
1916  auto inputShape = inputTy.getShape();
1917  auto resultTy = cast<ShapedType>(op.getType());
1918  auto elementTy = inputTy.getElementType();
1919  int64_t rank = inputTy.getRank();
1920 
1921  SmallVector<int64_t> multiples;
1922  if (failed(op.getConstantMultiples(multiples)))
1923  return failure();
1924 
1925  // Broadcast the newly added dimensions to their appropriate multiple.
1926  SmallVector<int64_t, 2> genericShape;
1927  for (int i = 0; i < rank; i++) {
1928  int64_t dim = multiples[i];
1929  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1930  genericShape.push_back(inputShape[i]);
1931  }
1932 
1933  SmallVector<Value> dynDims;
1934  for (int i = 0; i < inputTy.getRank(); i++) {
1935  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1936  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1937  }
1938  }
1939 
1940  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1941  op.getLoc(), genericShape, elementTy, dynDims);
1942 
1943  // We need to map the input shape to the non-broadcasted dimensions.
1944  SmallVector<AffineExpr, 4> dimExprs;
1945  dimExprs.reserve(rank);
1946  for (unsigned i = 0; i < rank; ++i)
1947  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1948 
1949  auto readAffineMap =
1950  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1951  rewriter.getContext());
1952 
1953  SmallVector<AffineMap, 2> affineMaps = {
1954  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1955 
1956  auto genericOp = rewriter.create<linalg::GenericOp>(
1957  loc, RankedTensorType::get(genericShape, elementTy), input,
1958  ValueRange{emptyTensor}, affineMaps,
1959  getNParallelLoopsAttrs(genericShape.size()),
1960  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1961  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1962  });
1963 
1964  auto shapeValue = getTosaConstShape(
1965  rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
1966  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1967  op, resultTy, genericOp.getResult(0), shapeValue);
1968  return success();
1969  }
1970 };
1971 
1972 // Tosa argmax lowering represents the ArgMax op as a linalg.generic op,
1973 // producing two output buffers.
1974 //
1975 // The first output buffer contains the index of the found maximum value. It is
1976 // initialized to 0 and is of the resulting integer type.
1977 //
1978 // The second output buffer contains the maximum value found. It is initialized
1979 // to the minimum representable value of the input element type. After being
1980 // populated by the generic op, this buffer is discarded as only the index is
1981 // requested.
1982 //
1983 // The generic op updates both the maximum value and index if the
1984 // current value exceeds the running max.
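// For example (illustrative shapes), tosa.argmax over axis 1 of a
// tensor<3x4xf32> yields the tensor<3xi32> index result plus a discarded
// tensor<3xf32> of running maxima.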
1985 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1986 public:
1987  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
1988 
1989  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1990  PatternRewriter &rewriter) const final {
1991  auto loc = argmaxOp.getLoc();
1992  Value input = argmaxOp.getInput();
1993  auto inputTy = cast<ShapedType>(input.getType());
1994  auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
1995  auto inElementTy = inputTy.getElementType();
1996  auto outElementTy = resultTy.getElementType();
1997  int axis = argmaxOp.getAxis();
1998  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1999 
2000  if (!isa<IntegerType>(outElementTy))
2001  return rewriter.notifyMatchFailure(
2002  argmaxOp,
2003  "tosa.arg_max to linalg.* requires integer-like result type");
2004 
2005  SmallVector<Value> dynDims;
2006  for (int i = 0; i < inputTy.getRank(); i++) {
2007  if (inputTy.isDynamicDim(i) && i != axis) {
2008  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2009  }
2010  }
2011 
2012  // First fill the output buffer for the index.
2013  auto emptyTensorIdx = rewriter
2014  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2015  outElementTy, dynDims)
2016  .getResult();
2017  auto fillValueIdx = rewriter.create<arith::ConstantOp>(
2018  loc, rewriter.getIntegerAttr(outElementTy, 0));
2019  auto filledTensorIdx =
2020  rewriter
2021  .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
2022  ValueRange{emptyTensorIdx})
2023  .result();
2024 
2025  // Second fill the output buffer for the running max.
2026  auto emptyTensorMax = rewriter
2027  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2028  inElementTy, dynDims)
2029  .getResult();
2030  auto fillValueMaxAttr =
2031  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2032 
2033  if (!fillValueMaxAttr)
2034  return rewriter.notifyMatchFailure(
2035  argmaxOp, "unsupported tosa.argmax element type");
2036 
2037  auto fillValueMax =
2038  rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
2039  auto filledTensorMax =
2040  rewriter
2041  .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
2042  ValueRange{emptyTensorMax})
2043  .result();
2044 
2045  // We need to reduce along the arg-max axis, with parallel operations along
2046  // the rest.
2047  SmallVector<utils::IteratorType> iteratorTypes;
2048  iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2049  iteratorTypes[axis] = utils::IteratorType::reduction;
2050 
2051  SmallVector<AffineExpr, 2> srcExprs;
2052  SmallVector<AffineExpr, 2> dstExprs;
2053  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2054  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2055  if (axis != i)
2056  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2057  }
2058 
2059  bool didEncounterError = false;
2060  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2061  rewriter.getContext());
2062  auto linalgOp = rewriter.create<linalg::GenericOp>(
2063  loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2064  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2065  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2066  ValueRange blockArgs) {
2067  auto newValue = blockArgs[0];
2068  auto oldIndex = blockArgs[1];
2069  auto oldValue = blockArgs[2];
2070 
2071  Value newIndex = rewriter.create<arith::IndexCastOp>(
2072  nestedLoc, oldIndex.getType(),
2073  rewriter.create<linalg::IndexOp>(loc, axis));
2074 
2075  Value predicate;
2076  if (isa<FloatType>(inElementTy)) {
2077  predicate = rewriter.create<arith::CmpFOp>(
2078  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2079  } else if (isa<IntegerType>(inElementTy)) {
2080  predicate = rewriter.create<arith::CmpIOp>(
2081  nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2082  } else {
2083  didEncounterError = true;
2084  return;
2085  }
2086 
2087  auto resultMax = rewriter.create<arith::SelectOp>(
2088  nestedLoc, predicate, newValue, oldValue);
2089  auto resultIndex = rewriter.create<arith::SelectOp>(
2090  nestedLoc, predicate, newIndex, oldIndex);
2091  nestedBuilder.create<linalg::YieldOp>(
2092  nestedLoc, ValueRange({resultIndex, resultMax}));
2093  });
2094 
2095  if (didEncounterError)
2096  return rewriter.notifyMatchFailure(
2097  argmaxOp, "unsupported tosa.argmax element type");
2098 
2099  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2100  return success();
2101  }
2102 };
2103 
2104 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2105 public:
2106  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2107  LogicalResult
2108  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2109  ConversionPatternRewriter &rewriter) const final {
2110  auto input = adaptor.getOperands()[0];
2111  auto indices = adaptor.getOperands()[1];
2112 
2113  auto valuesTy =
2114  dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2115  auto resultTy = cast<ShapedType>(op.getType());
2116 
2117  if (!valuesTy)
2118  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2119 
2120  auto dynamicDims = inferDynamicDimsForGather(
2121  rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2122 
2123  auto resultElementTy = resultTy.getElementType();
2124 
2125  auto loc = op.getLoc();
2126  auto emptyTensor =
2127  rewriter
2128  .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2129  dynamicDims)
2130  .getResult();
2131 
2132  SmallVector<AffineMap, 2> affineMaps = {
2133  AffineMap::get(
2134  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2135  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2136  rewriter.getContext()),
2137  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2138 
2139  auto genericOp = rewriter.create<linalg::GenericOp>(
2140  loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2141  ValueRange{emptyTensor}, affineMaps,
2142  getNParallelLoopsAttrs(resultTy.getRank()),
2143  [&](OpBuilder &b, Location loc, ValueRange args) {
2144  auto indexValue = args[0];
2145  auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2146  Value index1 = rewriter.create<arith::IndexCastOp>(
2147  loc, rewriter.getIndexType(), indexValue);
2148  auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2149  Value extract = rewriter.create<tensor::ExtractOp>(
2150  loc, input, ValueRange{index0, index1, index2});
2151  rewriter.create<linalg::YieldOp>(loc, extract);
2152  });
2153  rewriter.replaceOp(op, genericOp.getResult(0));
2154  return success();
2155  }
2156 
2157  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2158  Location loc,
2159  Value values,
2160  Value indices) {
2161  llvm::SmallVector<Value> results;
2162 
2163  auto addDynamicDimension = [&](Value source, int64_t dim) {
2164  auto sz = tensor::getMixedSize(builder, loc, source, dim);
2165  if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2166  results.push_back(dimValue);
2167  };
2168 
2169  addDynamicDimension(values, 0);
2170  addDynamicDimension(indices, 1);
2171  addDynamicDimension(values, 2);
2172  return results;
2173  }
2174 };
2175 
2176 // Lowers the TableOp to a series of gathers and numeric operations. This
2177 // includes interpolation between the high/low values. For the I8 variant, this
2178 // simplifies to a single gather operation.
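// For the I8 case the body below simply offsets the signed input by 128 so
// that values in [-128, 127] index table entries [0, 255].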
2179 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2180 public:
2181  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2182 
2183  LogicalResult matchAndRewrite(tosa::TableOp op,
2184  PatternRewriter &rewriter) const final {
2185  auto loc = op.getLoc();
2186  Value input = op.getInput1();
2187  Value table = op.getTable();
2188  auto inputTy = cast<ShapedType>(input.getType());
2189  auto tableTy = cast<ShapedType>(table.getType());
2190  auto resultTy = cast<ShapedType>(op.getType());
2191 
2192  auto inputElementTy = inputTy.getElementType();
2193  auto tableElementTy = tableTy.getElementType();
2194  auto resultElementTy = resultTy.getElementType();
2195 
2196  SmallVector<Value> dynDims;
2197  for (int i = 0; i < resultTy.getRank(); ++i) {
2198  if (inputTy.isDynamicDim(i)) {
2199  dynDims.push_back(
2200  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2201  }
2202  }
2203 
2204  auto emptyTensor = rewriter
2205  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2206  resultElementTy, dynDims)
2207  .getResult();
2208 
2209  SmallVector<AffineMap, 2> affineMaps = {
2210  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2211  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2212 
2213  auto genericOp = rewriter.create<linalg::GenericOp>(
2214  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2215  getNParallelLoopsAttrs(resultTy.getRank()));
2216  rewriter.replaceOp(op, genericOp.getResult(0));
2217 
2218  {
2219  OpBuilder::InsertionGuard regionGuard(rewriter);
2220  Block *block = rewriter.createBlock(
2221  &genericOp.getRegion(), genericOp.getRegion().end(),
2222  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2223 
2224  auto inputValue = block->getArgument(0);
2225  rewriter.setInsertionPointToStart(block);
2226  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2227  resultElementTy.isInteger(8)) {
2228  Value index = rewriter.create<arith::IndexCastOp>(
2229  loc, rewriter.getIndexType(), inputValue);
2230  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2231  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2232  index, offset);
2233  Value extract =
2234  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2235  rewriter.create<linalg::YieldOp>(loc, extract);
2236  return success();
2237  }
2238 
2239  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2240  resultElementTy.isInteger(32)) {
2241  Value extend = rewriter.create<arith::ExtSIOp>(
2242  loc, rewriter.getI32Type(), inputValue);
2243 
2244  auto offset = rewriter.create<arith::ConstantOp>(
2245  loc, rewriter.getI32IntegerAttr(32768));
2246  auto seven = rewriter.create<arith::ConstantOp>(
2247  loc, rewriter.getI32IntegerAttr(7));
2248  auto one = rewriter.create<arith::ConstantOp>(
2249  loc, rewriter.getI32IntegerAttr(1));
2250  auto b1111111 = rewriter.create<arith::ConstantOp>(
2251  loc, rewriter.getI32IntegerAttr(127));
2252 
2253  // Compute the index and fractional part from the input value:
2254  // value = value + 32768
2255  // index = value >> 7;
2256  // fraction = 0x7F & value
2257  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2258  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2259  Value fraction =
2260  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2261 
2262  // Extract the base and next values from the table.
2263  // base = (int32_t) table[index];
2264  // next = (int32_t) table[index + 1];
2265  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2266 
2267  index = rewriter.create<arith::IndexCastOp>(
2268  loc, rewriter.getIndexType(), index);
2269  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2270  loc, rewriter.getIndexType(), indexPlusOne);
2271 
2272  Value base =
2273  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2274  Value next = rewriter.create<tensor::ExtractOp>(
2275  loc, table, ValueRange{indexPlusOne});
2276 
2277  base =
2278  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2279  next =
2280  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2281 
2282  // Use the fractional part to interpolate between the input values:
2283  // result = (base << 7) + (next - base) * fraction
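  // Worked example (illustrative values): base = 100, next = 110,
  // fraction = 64 gives (100 << 7) + 10 * 64 = 12800 + 640 = 13440, the
  // midpoint between the two scaled table entries.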
2284  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2285  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2286  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2287  Value result =
2288  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2289 
2290  rewriter.create<linalg::YieldOp>(loc, result);
2291 
2292  return success();
2293  }
2294  }
2295 
2296  return rewriter.notifyMatchFailure(
2297  op, "unable to create body for tosa.table op");
2298  }
2299 };
2300 
2301 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2302  using OpRewritePattern::OpRewritePattern;
2303 
2304  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2305 
2306  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2307  OpFoldResult ofr) {
2308  auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
2309  auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
2310 
2311  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2312  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2313  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2314  return getAsOpFoldResult(plusOne);
2315  }
2316 
2317  static RankedTensorType
2318  computeOutputShape(OpBuilder &builder, Location loc, Value input,
2319  llvm::SmallVectorImpl<Value> &dynamicSizes) {
2320  // Get [N, H, W]
2321  auto dims = tensor::getMixedSizes(builder, loc, input);
2322 
2323  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2324  // output tensors.
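  // For example, an input width of 8 yields 8 / 2 + 1 = 5 output frequency
  // bins.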
2325  dims[2] = halfPlusOne(builder, loc, dims[2]);
2326 
2327  llvm::SmallVector<int64_t, 3> staticSizes;
2328  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2329 
2330  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2331  return RankedTensorType::get(staticSizes, elementType);
2332  }
2333 
2334  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2335  RankedTensorType type,
2336  llvm::ArrayRef<Value> dynamicSizes) {
2337  auto emptyTensor =
2338  rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2339  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2340  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2341  auto filledTensor = rewriter
2342  .create<linalg::FillOp>(loc, ValueRange{fillValue},
2343  ValueRange{emptyTensor})
2344  .result();
2345  return filledTensor;
2346  }
2347 
2348  static Value castIndexToFloat(OpBuilder &builder, Location loc,
2349  FloatType type, Value value) {
2350  auto integerVal = builder.create<arith::IndexCastUIOp>(
2351  loc,
2352  type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2353  : builder.getI32Type(),
2354  value);
2355 
2356  return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2357  }
2358 
2359  static Value createLinalgIndex(OpBuilder &builder, Location loc,
2360  FloatType type, int64_t index) {
2361  auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2362  return castIndexToFloat(builder, loc, type, indexVal);
2363  }
2364 
2365  template <typename... Args>
2366  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2367  Args... args) {
2368  return {builder.getAffineDimExpr(args)...};
2369  }
2370 
2371  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2372  PatternRewriter &rewriter) const override {
2373  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2374  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2375  return rewriter.notifyMatchFailure(rfft2d,
2376  "only supports ranked tensors");
2377  }
2378 
2379  auto loc = rfft2d.getLoc();
2380  auto input = rfft2d.getInput();
2381  auto elementType =
2382  dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2383  if (!elementType)
2384  return rewriter.notifyMatchFailure(rfft2d,
2385  "only supports float element types");
2386 
2387  // Compute the output type and set of dynamic sizes
2388  llvm::SmallVector<Value> dynamicSizes;
2389  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2390 
2391  // Iterator types for the linalg.generic implementation
2392  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2393  utils::IteratorType::parallel, utils::IteratorType::parallel,
2394  utils::IteratorType::parallel, utils::IteratorType::reduction,
2395  utils::IteratorType::reduction};
2396 
2397  // Inputs/outputs to the linalg.generic implementation
2398  llvm::SmallVector<Value> genericOpInputs = {input};
2399  llvm::SmallVector<Value> genericOpOutputs = {
2400  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2401  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2402 
2403  // Indexing maps for input and output tensors
2404  auto indexingMaps = AffineMap::inferFromExprList(
2405  llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2406  affineDimsExpr(rewriter, 0, 1, 2),
2407  affineDimsExpr(rewriter, 0, 1, 2)},
2408  rewriter.getContext());
2409 
2410  // Width and height dimensions of the original input.
2411  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2412  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2413 
2414  // Constants and dimension sizes
2415  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2416  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2417  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2418  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2419 
2420  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2421  Value valReal = args[0];
2422  Value sumReal = args[1];
2423  Value sumImag = args[2];
2424 
2425  // Indices for angle computation
2426  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2427  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2428  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2429  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2430 
2431  // Calculate the angle. The integer parts of the components can be dropped
2432  // since sin/cos are periodic:
2433  // angle = 2 * pi() * ( ( (iy * oy) % H ) / H + ( (ix * ox) % W ) / W );
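  // For example (illustrative values), with H = 8, oy = 3 and iy = 5 the y
  // term is ((5 * 3) mod 8) / 8 = 7 / 8 of a full period.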
2434  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2435  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2436 
2437  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2438  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2439 
2440  auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2441  auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2442 
2443  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2444  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2445  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2446  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2447 
2448  // realComponent = valReal * cos(angle)
2449  // imagComponent = valReal * sin(angle)
2450  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2451  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2452  auto realComponent =
2453  builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2454  auto imagComponent =
2455  builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2456 
2457  // outReal = sumReal + realComponent
2458  // outImag = sumImag - imagComponent
2459  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2460  auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2461 
2462  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2463  };
2464 
2465  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2466  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2467  indexingMaps, iteratorTypes, buildBody);
2468 
2469  return success();
2470  }
2471 };
2472 
2473 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2474  using OpRewritePattern::OpRewritePattern;
2475 
2476  LogicalResult matchAndRewrite(FFT2dOp fft2d,
2477  PatternRewriter &rewriter) const override {
2478  if (!llvm::all_of(fft2d->getOperandTypes(),
2479  RFFT2dConverter::isRankedTensor) ||
2480  !llvm::all_of(fft2d->getResultTypes(),
2481  RFFT2dConverter::isRankedTensor)) {
2482  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2483  }
2484 
2485  Location loc = fft2d.getLoc();
2486  Value input_real = fft2d.getInputReal();
2487  Value input_imag = fft2d.getInputImag();
2488  BoolAttr inverse = fft2d.getInverseAttr();
2489 
2490  auto real_el_ty = cast<FloatType>(
2491  cast<ShapedType>(input_real.getType()).getElementType());
2492  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2493  cast<ShapedType>(input_imag.getType()).getElementType());
2494 
2495  assert(real_el_ty == imag_el_ty);
2496 
2497  // Compute the output type and set of dynamic sizes
2498  SmallVector<Value> dynamicSizes;
2499 
2500  // Get [N, H, W]
2501  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2502 
2503  SmallVector<int64_t, 3> staticSizes;
2504  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2505 
2506  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2507 
2508  // Iterator types for the linalg.generic implementation
2509  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2510  utils::IteratorType::parallel, utils::IteratorType::parallel,
2511  utils::IteratorType::parallel, utils::IteratorType::reduction,
2512  utils::IteratorType::reduction};
2513 
2514  // Inputs/outputs to the linalg.generic implementation
2515  SmallVector<Value> genericOpInputs = {input_real, input_imag};
2516  SmallVector<Value> genericOpOutputs = {
2517  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2518  dynamicSizes),
2519  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2520  dynamicSizes)};
2521 
2522  // Indexing maps for input and output tensors
2523  auto indexingMaps = AffineMap::inferFromExprList(
2524  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2525  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2526  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2527  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2528  rewriter.getContext());
2529 
2530  // Width and height dimensions of the original input.
2531  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2532  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2533 
2534  // Constants and dimension sizes
2535  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2536  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2537  Value constH =
2538  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2539  Value constW =
2540  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2541 
2542  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2543  Value valReal = args[0];
2544  Value valImag = args[1];
2545  Value sumReal = args[2];
2546  Value sumImag = args[3];
2547 
2548  // Indices for angle computation
2549  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2550  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2551  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2552  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2553 
2554  // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2555  // ox) % W ) / W);
2556  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2557  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2558 
2559  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2560  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2561 
2562  auto iyRemFloat =
2563  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2564  auto ixRemFloat =
2565  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2566 
2567  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2568  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2569 
2570  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2571  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2572 
2573  if (inverse.getValue()) {
2574  angle = builder.create<arith::MulFOp>(
2575  loc, angle,
2576  rewriter.create<arith::ConstantOp>(
2577  loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2578  }
2579 
2580  // realComponent = val_real * cos(a) + val_imag * sin(a);
2581  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2582  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2583  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2584 
2585  auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2586  auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2587  auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2588 
2589  auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2590  auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2591 
2592  auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2593 
2594  // outReal = sumReal + realComponent
2595  // outImag = sumImag + imagComponent
2596  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2597  auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2598 
2599  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2600  };
2601 
2602  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2603  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2604  indexingMaps, iteratorTypes, buildBody);
2605 
2606  return success();
2607  }
2608 };
2609 
2610 } // namespace
2611 
2612 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2613  const TypeConverter &converter, RewritePatternSet *patterns) {
2614 
2615  // We have multiple resize converters to handle degenerate cases.
2616  patterns->add<GenericResizeConverter>(patterns->getContext(),
2617  /*benefit=*/100);
2618  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2619  /*benefit=*/200);
2620  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2621  /*benefit=*/300);
2622 
2623  patterns->add<
2624  // clang-format off
2625  PointwiseConverter<tosa::AddOp>,
2626  PointwiseConverter<tosa::SubOp>,
2627  PointwiseConverter<tosa::MulOp>,
2628  PointwiseConverter<tosa::IntDivOp>,
2629  PointwiseConverter<tosa::NegateOp>,
2630  PointwiseConverter<tosa::PowOp>,
2631  PointwiseConverter<tosa::ReciprocalOp>,
2632  PointwiseConverter<tosa::RsqrtOp>,
2633  PointwiseConverter<tosa::LogOp>,
2634  PointwiseConverter<tosa::ExpOp>,
2635  PointwiseConverter<tosa::AbsOp>,
2636  PointwiseConverter<tosa::SinOp>,
2637  PointwiseConverter<tosa::CosOp>,
2638  PointwiseConverter<tosa::TanhOp>,
2639  PointwiseConverter<tosa::ErfOp>,
2640  PointwiseConverter<tosa::BitwiseAndOp>,
2641  PointwiseConverter<tosa::BitwiseOrOp>,
2642  PointwiseConverter<tosa::BitwiseNotOp>,
2643  PointwiseConverter<tosa::BitwiseXorOp>,
2644  PointwiseConverter<tosa::LogicalAndOp>,
2645  PointwiseConverter<tosa::LogicalNotOp>,
2646  PointwiseConverter<tosa::LogicalOrOp>,
2647  PointwiseConverter<tosa::LogicalXorOp>,
2648  PointwiseConverter<tosa::CastOp>,
2649  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2650  PointwiseConverter<tosa::LogicalRightShiftOp>,
2651  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2652  PointwiseConverter<tosa::ClzOp>,
2653  PointwiseConverter<tosa::SelectOp>,
2654  PointwiseConverter<tosa::GreaterOp>,
2655  PointwiseConverter<tosa::GreaterEqualOp>,
2656  PointwiseConverter<tosa::EqualOp>,
2657  PointwiseConverter<tosa::MaximumOp>,
2658  PointwiseConverter<tosa::MinimumOp>,
2659  PointwiseConverter<tosa::CeilOp>,
2660  PointwiseConverter<tosa::FloorOp>,
2661  PointwiseConverter<tosa::ClampOp>,
2662  PointwiseConverter<tosa::SigmoidOp>
2663  >(converter, patterns->getContext());
2664 
2665  patterns->add<
2666  IdentityNConverter<tosa::IdentityOp>,
2667  ReduceConverter<tosa::ReduceAllOp>,
2668  ReduceConverter<tosa::ReduceAnyOp>,
2669  ReduceConverter<tosa::ReduceMinOp>,
2670  ReduceConverter<tosa::ReduceMaxOp>,
2671  ReduceConverter<tosa::ReduceSumOp>,
2672  ReduceConverter<tosa::ReduceProdOp>,
2673  ArgMaxConverter,
2674  GatherConverter,
2675  RescaleConverter,
2676  ReverseConverter,
2677  RFFT2dConverter,
2678  FFT2dConverter,
2679  TableConverter,
2680  TileConverter>(patterns->getContext());
2681  // clang-format on
2682 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor, int64_t rank)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter)
static SmallVector< Value > expandInputRanks(PatternRewriter &rewriter, Location loc, ValueRange operands, int64_t rank)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static bool operandsAndResultsRanked(Operation *operation)
const float * table
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:368
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
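Sketch of replaceOpWithNewOp: build the replacement op and forward its results to the users of the original in one call (the op choice is illustrative):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

// Replace `op` with an arith.addf over the two given operands.
static void replaceWithAddF(mlir::PatternRewriter &rewriter,
                            mlir::Operation *op, mlir::Value lhs,
                            mlir::Value rhs) {
  rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(op, lhs, rhs);
}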
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
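Sketch of querying the TypeConverter before building new ops; this uses the single-type convertType overload, which returns a null Type when no conversion is registered:

#include "mlir/Transforms/DialectConversion.h"

// Null result means `original` has no registered conversion.
static mlir::Type convertOrNull(const mlir::TypeConverter &converter,
                                mlir::Type original) {
  return converter.convertType(original);
}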
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:80
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:47
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:114
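Sketch of the element-type queries above as used when choosing between signed and unsigned code paths (the predicate is illustrative):

#include "mlir/IR/Types.h"

// True for unsigned integer element types wider than 32 bits.
static bool isWideUnsignedInt(mlir::Type elementTy) {
  return elementTy.isUnsignedInteger() &&
         elementTy.getIntOrFloatBitWidth() > 32;
}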
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
Type front()
Return first type in the range.
Definition: TypeRange.h:149
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:59
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:69
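Sketch of getMixedSizes feeding a tensor.empty init operand, the common pattern when a lowering needs an output tensor shaped like its input (the helper name is illustrative):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

// Create a tensor.empty with the same (possibly dynamic) sizes as `source`.
static mlir::Value makeEmptyLike(mlir::OpBuilder &builder, mlir::Location loc,
                                 mlir::Value source, mlir::Type elementTy) {
  auto sizes = mlir::tensor::getMixedSizes(builder, loc, source);
  return builder.create<mlir::tensor::EmptyOp>(loc, sizes, elementTy);
}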
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
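A sketch of wiring this entry point into a dialect conversion, assuming the caller already built a TypeConverter and ConversionTarget (the function name and setup are illustrative):

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

// Collect the TOSA->Linalg patterns and run a full conversion over `root`.
static mlir::LogicalResult runTosaToLinalg(mlir::Operation *root,
                                           const mlir::TypeConverter &converter,
                                           mlir::ConversionTarget &target) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
  return mlir::applyFullConversion(root, target, std::move(patterns));
}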
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
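Sketch of clampIntHelper for a signed clamp; the include path is assumed to be the TOSA conversion utils header where these helpers are declared:

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Builders.h"

// Clamp `value` into [min, max], treating it as signed.
static mlir::Value clampSigned(mlir::OpBuilder &builder, mlir::Location loc,
                               mlir::Value value, mlir::Value min,
                               mlir::Value max) {
  return mlir::tosa::clampIntHelper(loc, value, min, max, builder,
                                    /*isUnsigned=*/false);
}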
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
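Sketch of the matchPattern + m_Constant idiom used above to read a constant operand's elements (the helper name is illustrative):

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"

// Bind the elements of a constant-producing value; returns false otherwise.
static bool getConstantElements(mlir::Value value, mlir::ElementsAttr &attr) {
  return mlir::matchPattern(value, mlir::m_Constant(&attr));
}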
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
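Sketch of normalizing an OpFoldResult into an SSA index value for ops that only accept Values; the include path is assumed to be the arith utils header:

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/IR/Builders.h"

// Materialize a constant index op if `ofr` holds an attribute; otherwise
// return the existing Value.
static mlir::Value asIndexValue(mlir::OpBuilder &builder, mlir::Location loc,
                                mlir::OpFoldResult ofr) {
  return mlir::getValueOrCreateConstantIndexOp(builder, loc, ofr);
}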
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:362
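A skeleton of the OpRewritePattern shape that the patterns above follow; this illustrative pattern never rewrites anything and only shows the matchAndRewrite contract:

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct ExamplePattern : public mlir::OpRewritePattern<mlir::tosa::ConstOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::tosa::ConstOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Illustrative only: report a match failure instead of rewriting.
    return rewriter.notifyMatchFailure(op, "illustrative pattern; no rewrite");
  }
};
} // namespace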