MLIR  20.0.0git
TosaToLinalg.cpp
Go to the documentation of this file.
1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 
34 #include <numeric>
35 
36 using namespace mlir;
37 using namespace mlir::tosa;
38 
39 template <typename T>
40 static arith::ConstantOp
41 createConstFromIntAttribute(Operation *op, const std::string &attrName,
42  Type requiredAttrType, OpBuilder &rewriter) {
43  auto castedN = static_cast<T>(
44  cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
45  return rewriter.create<arith::ConstantOp>(
46  op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
47 }
48 
50  Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51  ConversionPatternRewriter &rewriter) {
52  Location loc = op->getLoc();
53  auto elementTy =
54  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
55 
56  // tosa::AbsOp
57  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
58  return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
59 
60  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
61  auto zero = rewriter.create<arith::ConstantOp>(
62  loc, rewriter.getZeroAttr(elementTy));
63  auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
64  return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
65  }
66 
67  // tosa::AddOp
68  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
69  return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
70 
71  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
72  return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
73 
74  // tosa::SubOp
75  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
76  return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
77 
78  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
79  return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
80 
81  // tosa::IntDivOp
82  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
83  return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
84 
85  // tosa::ReciprocalOp
86  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
87  auto one =
88  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
89  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
90  }
91 
92  // tosa::MulOp
93  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
94  return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
95 
96  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
97  Value a = args[0];
98  Value b = args[1];
99  auto shift =
100  cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
101  if (shift > 0) {
102  auto shiftConst =
103  rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
104  if (!a.getType().isInteger(32))
105  a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
106 
107  if (!b.getType().isInteger(32))
108  b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
109 
110  auto result = rewriter.create<tosa::ApplyScaleOp>(
111  loc, rewriter.getI32Type(), a, b, shiftConst,
112  rewriter.getBoolAttr(false));
113 
114  if (elementTy.isInteger(32))
115  return result;
116 
117  return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
118  }
119 
120  int aWidth = a.getType().getIntOrFloatBitWidth();
121  int bWidth = b.getType().getIntOrFloatBitWidth();
122  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
123 
124  if (aWidth < cWidth)
125  a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
126  if (bWidth < cWidth)
127  b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
128 
129  return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
130  }
131 
132  // tosa::NegateOp
133  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
134  return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
135 
136  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
137  int64_t inZp = 0, outZp = 0;
138 
139  if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
140  auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
141  inZp = quantizationInfo.value().getInputZp();
142  outZp = quantizationInfo.value().getOutputZp();
143  }
144 
145  int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
146  if (!inZp && !outZp) {
147  auto constant = rewriter.create<arith::ConstantOp>(
148  loc, IntegerAttr::get(elementTy, 0));
149  return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
150  args[0]);
151  }
152 
153  // Compute the maximum value that can occur in the intermediate buffer.
154  int64_t zpAdd = inZp + outZp;
155  int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
156  std::abs(zpAdd) + 1;
157 
158  // Convert that maximum value into the maximum bitwidth needed to represent
159  // it. We assume 48-bit numbers may be supported further in the pipeline.
160  int intermediateBitWidth = 64;
161  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
162  intermediateBitWidth = 16;
163  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
164  intermediateBitWidth = 32;
165  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
166  intermediateBitWidth = 48;
167  }
168 
169  Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
170  Value zpAddValue = rewriter.create<arith::ConstantOp>(
171  loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
172 
173  // The negation can be applied by doing:
174  // outputValue = inZp + outZp - inputValue
175  auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
176  auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
177 
178  // Clamp to the negation range.
179  Value min = rewriter.create<arith::ConstantIntOp>(
180  loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
181  intermediateType);
182  Value max = rewriter.create<arith::ConstantIntOp>(
183  loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
184  intermediateType);
185  auto clamp =
186  clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
187 
188  // Truncate to the final value.
189  return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
190  }
191 
192  // tosa::BitwiseAndOp
193  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
194  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
195 
196  // tosa::BitwiseOrOp
197  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
198  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
199 
200  // tosa::BitwiseNotOp
201  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
202  auto allOnesAttr = rewriter.getIntegerAttr(
203  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
204  auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
205  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
206  }
207 
208  // tosa::BitwiseXOrOp
209  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
210  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
211 
212  // tosa::LogicalLeftShiftOp
213  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
214  return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
215 
216  // tosa::LogicalRightShiftOp
217  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
218  return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
219 
220  // tosa::ArithmeticRightShiftOp
221  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
222  auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
223  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
224  if (!round) {
225  return result;
226  }
227 
228  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
229  auto one =
230  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
231  auto zero =
232  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
233  auto i1one =
234  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
235 
236  // Checking that input2 != 0
237  auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
238  loc, arith::CmpIPredicate::sgt, args[1], zero);
239 
240  // Checking for the last bit of input1 to be 1
241  auto subtract =
242  rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
243  auto shifted =
244  rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
245  ->getResults();
246  auto truncated =
247  rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
248  auto isInputOdd =
249  rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
250 
251  auto shouldRound = rewriter.create<arith::AndIOp>(
252  loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
253  auto extended =
254  rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
255  return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
256  }
257 
258  // tosa::ClzOp
259  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
260  return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
261  }
262 
263  // tosa::LogicalAnd
264  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
265  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
266 
267  // tosa::LogicalNot
268  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
269  auto one = rewriter.create<arith::ConstantOp>(
270  loc, rewriter.getIntegerAttr(elementTy, 1));
271  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
272  }
273 
274  // tosa::LogicalOr
275  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
276  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
277 
278  // tosa::LogicalXor
279  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
280  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
281 
282  // tosa::PowOp
283  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
284  return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
285 
286  // tosa::RsqrtOp
287  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
288  return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
289 
290  // tosa::LogOp
291  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
292  return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
293 
294  // tosa::ExpOp
295  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
296  return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
297 
298  // tosa::SinOp
299  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
300  return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
301 
302  // tosa::CosOp
303  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
304  return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
305 
306  // tosa::TanhOp
307  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
308  return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
309 
310  // tosa::ErfOp
311  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
312  return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
313 
314  // tosa::GreaterOp
315  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
316  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
317  args[0], args[1]);
318 
319  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
320  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
321  args[0], args[1]);
322 
323  // tosa::GreaterEqualOp
324  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
325  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
326  args[0], args[1]);
327 
328  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
329  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
330  args[0], args[1]);
331 
332  // tosa::EqualOp
333  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
334  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
335  args[0], args[1]);
336 
337  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
338  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
339  args[0], args[1]);
340 
341  // tosa::SelectOp
342  if (isa<tosa::SelectOp>(op)) {
343  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
344  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
345  return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
346  }
347 
348  // tosa::MaximumOp
349  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
350  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
351  }
352 
353  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
354  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
355  }
356 
357  // tosa::MinimumOp
358  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
359  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
360  }
361 
362  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
363  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
364  }
365 
366  // tosa::CeilOp
367  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
368  return rewriter.create<math::CeilOp>(loc, resultTypes, args);
369 
370  // tosa::FloorOp
371  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
372  return rewriter.create<math::FloorOp>(loc, resultTypes, args);
373 
374  // tosa::ClampOp
375  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
376  bool losesInfo = false;
377  APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
378  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
379  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
380  APFloat::rmNearestTiesToEven, &losesInfo);
381  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
382  APFloat::rmNearestTiesToEven, &losesInfo);
383  auto min = rewriter.create<arith::ConstantOp>(
384  loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
385  auto max = rewriter.create<arith::ConstantOp>(
386  loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
387  return clampFloatHelper(loc, args[0], min, max, rewriter);
388  }
389 
390  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
391  auto intTy = cast<IntegerType>(elementTy);
392  int64_t min =
393  cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
394  int64_t max =
395  cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
396 
397  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
398  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
399  if (intTy.isUnsignedInteger()) {
400  minRepresentable = 0;
401  if (intTy.getIntOrFloatBitWidth() <= 63) {
402  maxRepresentable =
403  (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
404  .getZExtValue();
405  }
406  } else if (intTy.getIntOrFloatBitWidth() <= 64) {
407  // Ensure that min & max fit into signed n-bit constants.
408  minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
409  .getSExtValue();
410  maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
411  .getSExtValue();
412  }
413  // Ensure that the bounds are representable as n-bit signed/unsigned
414  // integers.
415  min = std::max(min, minRepresentable);
416  max = std::max(max, minRepresentable);
417  min = std::min(min, maxRepresentable);
418  max = std::min(max, maxRepresentable);
419 
420  auto minVal = rewriter.create<arith::ConstantIntOp>(
421  loc, min, intTy.getIntOrFloatBitWidth());
422  auto maxVal = rewriter.create<arith::ConstantIntOp>(
423  loc, max, intTy.getIntOrFloatBitWidth());
424  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
425  intTy.isUnsignedInteger());
426  }
427 
428  // tosa::SigmoidOp
429  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
430  auto one =
431  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
432  auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
433  auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
434  auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
435  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
436  }
437 
438  // tosa::CastOp
439  if (isa<tosa::CastOp>(op)) {
440  Type srcTy = elementTy;
441  Type dstTy = resultTypes.front();
442  bool bitExtend =
444 
445  if (srcTy == dstTy)
446  return args.front();
447 
448  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
449  return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
450  std::nullopt);
451 
452  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
453  return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
454  std::nullopt);
455 
456  // 1-bit integers need to be treated as signless.
457  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
458  return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
459  std::nullopt);
460 
461  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
462  return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
463  std::nullopt);
464 
465  // Unsigned integers need an unrealized cast so that they can be passed
466  // to UIToFP.
467  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
468  auto unrealizedCast =
469  rewriter
470  .create<UnrealizedConversionCastOp>(
471  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
472  args[0])
473  .getResult(0);
474  return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
475  unrealizedCast);
476  }
477 
478  // All other si-to-fp conversions should be handled by SIToFP.
479  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
480  return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
481  std::nullopt);
482 
483  // Casting to boolean, floats need to only be checked as not-equal to zero.
484  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
485  Value zero = rewriter.create<arith::ConstantOp>(
486  loc, rewriter.getFloatAttr(srcTy, 0.0));
487  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
488  args.front(), zero);
489  }
490 
491  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
492  auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
493 
494  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
495  // Check whether neither int min nor int max can be represented in the
496  // input floating-point type due to too short exponent range.
497  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
498  APFloat::semanticsMaxExponent(fltSemantics)) {
499  // Use cmp + select to replace infinites by int min / int max. Other
500  // integral values can be represented in the integer space.
501  auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
502  auto posInf = rewriter.create<arith::ConstantOp>(
503  loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
504  APFloat::getInf(fltSemantics)));
505  auto negInf = rewriter.create<arith::ConstantOp>(
506  loc, rewriter.getFloatAttr(
507  getElementTypeOrSelf(srcTy),
508  APFloat::getInf(fltSemantics, /*Negative=*/true)));
509  auto overflow = rewriter.create<arith::CmpFOp>(
510  loc, arith::CmpFPredicate::UEQ, rounded, posInf);
511  auto underflow = rewriter.create<arith::CmpFOp>(
512  loc, arith::CmpFPredicate::UEQ, rounded, negInf);
513  auto intMin = rewriter.create<arith::ConstantOp>(
514  loc, rewriter.getIntegerAttr(
515  getElementTypeOrSelf(dstTy),
516  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
517  auto intMax = rewriter.create<arith::ConstantOp>(
518  loc, rewriter.getIntegerAttr(
519  getElementTypeOrSelf(dstTy),
520  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
521  auto maxClamped =
522  rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
523  return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
524  maxClamped);
525  }
526 
527  auto intMinFP = rewriter.create<arith::ConstantOp>(
528  loc, rewriter.getFloatAttr(
529  getElementTypeOrSelf(srcTy),
530  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
531  .getSExtValue()));
532 
533  // Check whether the mantissa has enough bits to represent int max.
534  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
535  dstTy.getIntOrFloatBitWidth() - 1) {
536  // Int min can also be represented since it is a power of two and thus
537  // consists of a single leading bit. Therefore we can clamp the input
538  // in the floating-point domain.
539 
540  auto intMaxFP = rewriter.create<arith::ConstantOp>(
541  loc, rewriter.getFloatAttr(
542  getElementTypeOrSelf(srcTy),
543  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
544  .getSExtValue()));
545 
546  Value clamped =
547  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
548  return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
549  }
550 
551  // Due to earlier check we know exponent range is big enough to represent
552  // int min. We can therefore rely on int max + 1 being representable as
553  // well because it's just int min with a positive sign. So clamp the min
554  // value and compare against that to select the max int value if needed.
555  auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
556  loc, rewriter.getFloatAttr(
557  getElementTypeOrSelf(srcTy),
558  static_cast<double>(
559  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
560  .getSExtValue()) +
561  1.0f));
562 
563  auto intMax = rewriter.create<arith::ConstantOp>(
564  loc, rewriter.getIntegerAttr(
565  getElementTypeOrSelf(dstTy),
566  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
567  auto minClampedFP =
568  rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
569  auto minClamped =
570  rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
571  auto overflow = rewriter.create<arith::CmpFOp>(
572  loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
573  return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
574  minClamped);
575  }
576 
577  // Casting to boolean, integers need to only be checked as not-equal to
578  // zero.
579  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
580  Value zero = rewriter.create<arith::ConstantIntOp>(
581  loc, 0, srcTy.getIntOrFloatBitWidth());
582  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
583  args.front(), zero);
584  }
585 
586  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
587  return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
588  std::nullopt);
589 
590  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
591  return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
592  }
593  }
594 
595  (void)rewriter.notifyMatchFailure(
596  op, "unhandled op for linalg body calculation for elementwise op");
597  return nullptr;
598 }
599 
600 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
601  int64_t rank) {
602  // No need to expand if we are already at the desired rank
603  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
604  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
605  int64_t numExtraDims = rank - shapedType.getRank();
606  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
607  if (!numExtraDims)
608  return tensor;
609 
610  // Compute reassociation indices
611  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
612  shapedType.getRank());
613  int64_t index = 0;
614  for (index = 0; index <= numExtraDims; index++)
615  reassociationIndices[0].push_back(index);
616  for (size_t position = 1; position < reassociationIndices.size(); position++)
617  reassociationIndices[position].push_back(index++);
618 
619  // Compute result type
620  SmallVector<int64_t> resultShape;
621  for (index = 0; index < numExtraDims; index++)
622  resultShape.push_back(1);
623  for (auto size : shapedType.getShape())
624  resultShape.push_back(size);
625  auto resultType =
626  RankedTensorType::get(resultShape, shapedType.getElementType());
627 
628  // Emit 'tensor.expand_shape' op
629  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
630  reassociationIndices);
631 }
632 
634  Location loc, ValueRange operands,
635  int64_t rank) {
636  return llvm::map_to_vector(operands, [&](Value operand) {
637  return expandRank(rewriter, loc, operand, rank);
638  });
639 }
640 
642 
643 // Emit an 'arith.constant' op for the given index if it has not been created
644 // yet, or return an existing constant. This will prevent an excessive creation
645 // of redundant constants, easing readability of emitted code for unit tests.
647  IndexPool &indexPool, int64_t index) {
648  auto [it, inserted] = indexPool.try_emplace(index);
649  if (inserted)
650  it->second =
651  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
652  return it->second;
653 }
654 
656  IndexPool &indexPool, Value tensor, int64_t index) {
657  auto indexValue = createIndex(rewriter, loc, indexPool, index);
658  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
659 }
660 
662  IndexPool &indexPool, Value tensor,
663  int64_t index) {
664  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
665  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
666  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
667  if (shapedType.isDynamicDim(index))
668  return getTensorDim(rewriter, loc, indexPool, tensor, index);
669  return rewriter.getIndexAttr(shapedType.getDimSize(index));
670 }
671 
672 static bool operandsAndResultsRanked(Operation *operation) {
673  auto isRanked = [](Value value) {
674  return isa<RankedTensorType>(value.getType());
675  };
676  return llvm::all_of(operation->getOperands(), isRanked) &&
677  llvm::all_of(operation->getResults(), isRanked);
678 }
679 
680 // Compute the runtime dimension size for dimension 'dim' of the output by
681 // inspecting input 'operands', all of which are expected to have the same rank.
682 // This function returns a pair {targetSize, masterOperand}.
683 //
684 // The runtime size of the output dimension is returned either as a statically
685 // computed attribute or as a runtime SSA value.
686 //
687 // If the target size was inferred directly from one dominating operand, that
688 // operand is returned in 'masterOperand'. If the target size is inferred from
689 // multiple operands, 'masterOperand' is set to nullptr.
690 static std::pair<OpFoldResult, Value>
692  ValueRange operands, int64_t dim) {
693  // If any input operand contains a static size greater than 1 for this
694  // dimension, that is the target size. An occurrence of an additional static
695  // dimension greater than 1 with a different value is undefined behavior.
696  for (auto operand : operands) {
697  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
698  if (!ShapedType::isDynamic(size) && size > 1)
699  return {rewriter.getIndexAttr(size), operand};
700  }
701 
702  // Filter operands with dynamic dimension
703  auto operandsWithDynamicDim =
704  llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
705  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
706  }));
707 
708  // If no operand has a dynamic dimension, it means all sizes were 1
709  if (operandsWithDynamicDim.empty())
710  return {rewriter.getIndexAttr(1), operands.front()};
711 
712  // Emit code that computes the runtime size for this dimension. If there is
713  // only one operand with a dynamic dimension, it is considered the master
714  // operand that determines the runtime size of the output dimension.
715  auto targetSize =
716  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
717  if (operandsWithDynamicDim.size() == 1)
718  return {targetSize, operandsWithDynamicDim[0]};
719 
720  // Calculate maximum size among all dynamic dimensions
721  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
722  auto nextSize =
723  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
724  targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
725  }
726  return {targetSize, nullptr};
727 }
728 
729 // Compute the runtime output size for all dimensions. This function returns
730 // a pair {targetShape, masterOperands}.
731 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
733  IndexPool &indexPool, ValueRange operands) {
734  assert(!operands.empty());
735  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
736  SmallVector<OpFoldResult> targetShape;
737  SmallVector<Value> masterOperands;
738  for (auto dim : llvm::seq<int64_t>(0, rank)) {
739  auto [targetSize, masterOperand] =
740  computeTargetSize(rewriter, loc, indexPool, operands, dim);
741  targetShape.push_back(targetSize);
742  masterOperands.push_back(masterOperand);
743  }
744  return {targetShape, masterOperands};
745 }
746 
// NOTE(review): this listing dropped the line carrying the function name;
// from the parameter list this is the declaration of
// `broadcastDynamicDimension(OpBuilder &rewriter, Location loc, ...)` --
// confirm against the original file.
//
// Expands dimension `dim` of `operand` up to `targetSize` when that dimension
// is dynamic. The broadcast is guarded by a runtime `size == 1` check emitted
// as an `scf.if`, so no work happens when the dimension already matches.
                                         IndexPool &indexPool, Value operand,
                                         int64_t dim, OpFoldResult targetSize,
                                         Value masterOperand) {
  // Nothing to do if this is a static dimension
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly inferred by only taking
  // this operand into account, there is no need to broadcast. This is an
  // optimization that will prevent redundant control flow, and constitutes the
  // main motivation for tracking "master operands".
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op
  // The input map pins dimension `dim` to constant index 0 (the broadcast
  // source element); the output map is the identity.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    // New constants could potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op
    // NOTE(review): the listing dropped one argument line of this call
    // (presumably the parallel iterator attributes) -- confirm upstream.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to original operand type if necessary
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if': dimension is already the runtime target
  // size, yield the operand unchanged.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
829 
// NOTE(review): the listing dropped the line carrying the function name here;
// this is the per-operand overload
// `broadcastDynamicDimensions(OpBuilder &rewriter, Location loc, ...)` --
// confirm against the original file.
//
// Broadcasts every dynamic dimension of `operand` to the corresponding entry
// of `targetShape`, threading the result of one expansion into the next.
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  // One target size and one master operand per dimension.
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}
843 
// Broadcasts the dynamic dimensions of every operand to the common target
// shape and returns the (possibly rewritten) operand list.
// NOTE(review): the listing dropped the line carrying the function name and
// leading parameters here (presumably
// `broadcastDynamicDimensions(OpBuilder &rewriter, Location loc,`) -- confirm.
static SmallVector<Value>
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations
  if (operands.size() == 1)
    return operands;

  // Broadcast dynamic dimensions operand by operand
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
859 
// Emits the `linalg.generic` that performs the scalar elementwise computation
// for `operation` over already-broadcast `operands`, then replaces the op.
// NOTE(review): the listing dropped the line carrying the function name and
// leading parameters here (presumably
// `emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location
// loc,`) -- confirm against the original file.
static LogicalResult
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1. The output affine map is an identity map.
  //
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Static unit dims always read index 0 (broadcast); other dims pass
      // their loop index through unchanged.
      auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
                                        : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      // NOTE(review): the listing dropped one argument line here (presumably
      // the parallel iterator attributes) -- confirm upstream.
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        // NOTE(review): the listing dropped the line that binds `opResult`
        // here -- a call that builds the scalar body op for this elementwise
        // operation from the leading block arguments. Confirm upstream.
          operation, blockArgs.take_front(operation->getNumOperands()),
          {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
915 
// Top-level driver for lowering a TOSA elementwise op: expands operand ranks,
// computes the common target shape, broadcasts dynamic dimensions, then emits
// the elementwise `linalg.generic`.
// NOTE(review): the listing dropped the line carrying the function name and
// leading parameters here (presumably
// `elementwiseMatchAndRewrite(Operation *operation, ValueRange operands,`) --
// confirm against the original file.
static LogicalResult
                           ConversionPatternRewriter &rewriter,
                           const TypeConverter &converter) {

  // Collect op properties
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto rank =
      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
  // Pipeline: rank-normalize -> infer target shape (tracking which operand
  // fixed each dim) -> materialize runtime broadcasts -> emit computation.
  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                    targetShape, converter);
}
942 
943 // Returns the constant initial value for a given reduction operation. The
944 // attribute type varies depending on the element type required.
945 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
946  PatternRewriter &rewriter) {
947  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
948  return rewriter.getFloatAttr(elementTy, 0.0);
949 
950  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
951  return rewriter.getIntegerAttr(elementTy, 0);
952 
953  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
954  return rewriter.getFloatAttr(elementTy, 1.0);
955 
956  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
957  return rewriter.getIntegerAttr(elementTy, 1);
958 
959  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
960  return rewriter.getFloatAttr(
961  elementTy, APFloat::getLargest(
962  cast<FloatType>(elementTy).getFloatSemantics(), false));
963 
964  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
965  return rewriter.getIntegerAttr(
966  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
967 
968  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
969  return rewriter.getFloatAttr(
970  elementTy, APFloat::getLargest(
971  cast<FloatType>(elementTy).getFloatSemantics(), true));
972 
973  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
974  return rewriter.getIntegerAttr(
975  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
976 
977  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
978  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
979 
980  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
981  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
982 
983  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
984  return rewriter.getFloatAttr(
985  elementTy, APFloat::getLargest(
986  cast<FloatType>(elementTy).getFloatSemantics(), true));
987 
988  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
989  return rewriter.getIntegerAttr(
990  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
991 
992  return {};
993 }
994 
// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
// NOTE(review): the listing dropped the signature line here (presumably
// `static Value createLinalgBodyCalculationForReduceOp(Operation *op,`) --
// confirm against the original file.
// `args` is the reduction region's block arguments -- presumably
// {input element, running accumulator} per linalg region conventions; verify
// against the caller. Returns a null Value for unsupported combinations.
                                                  ValueRange args,
                                                  Type elementTy,
                                                  PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // Boolean reductions lower to bitwise and/or on i1.
  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}
1042 
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // The reduced tensor drops `axis`; record runtime sizes of any dynamic
  // dims we keep so the tensor.empty below can be created.
  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  // Despite its name, this flag is set when a body op WAS successfully
  // created (`result` non-null below); the `!didEncounterError` check after
  // the builder then reports failure when no body was produced.
  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, input, filledTensor, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        // NOTE(review): the listing dropped the line binding `result` here --
        // a call that builds the scalar combining op for this reduction from
        // the block arguments. Confirm against the original file.
          op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  // Rebuild the original rank by expanding the reduced result: every kept
  // dim maps back to its pre-reduction position, and the reduced axis is
  // reinserted (folded into a neighboring reassociation group).
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
  // not have access to such information. This matters when handling dynamically
  // sized tensors.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}
1122 
1123 namespace {
1124 
// Conversion pattern that lowers any TOSA elementwise op `SrcOp` by
// delegating to the shared elementwise lowering driver.
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  // NOTE(review): the listing dropped the inherited-constructor `using`
  // declaration that normally sits here -- confirm against the original file.

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    // NOTE(review): the listing dropped the line with the call target here
    // (presumably `return elementwiseMatchAndRewrite(`) -- confirm.
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
1138 
// Lowers tosa.rescale (integer requantization) to a linalg.generic whose
// scalar body performs: widen to i32, subtract input zero-point, apply the
// fixed-point scale (tosa.apply_scale), add output zero-point, clamp to the
// output range, and truncate to the output width.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  // NOTE(review): the listing dropped the inherited-constructor `using`
  // declaration that normally sits here -- confirm against the original file.

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration. terminate and log an error
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    // Runtime sizes of dynamic output dims, needed by tensor.empty below.
    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues(op.getMultiplier());
    SmallVector<int8_t> shiftValues(op.getShift());

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double round only occurs if shift is greater than 31, check that this
    // is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    // Per-channel multipliers become an extra 1-D input indexed by the
    // innermost (last) dimension.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the indexing maps needed for linalg.generic ops.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // For now we do all of our math in 64-bit. This is not optimal but
          // should be correct for now, consider computing correct bit depth
          // later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          // Per-tensor values come from the hoisted constants; per-channel
          // values come in as extra block arguments.
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Widen sub-32-bit elements to i32; unsigned inputs are first
          // routed through an unrealized cast so the zero-extend operates on
          // a signless type.
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a difference output value.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
1333 
// Handle the resize case where the input is a 1x1 image. This case
// can entirely avoiding having extract operations which target much
// more difficult to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  // NOTE(review): the listing dropped the inherited-constructor `using`
  // declaration that normally sits here -- confirm against the original file.

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Only handles the trivial 1x1 -> 1x1 spatial case.
    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    // TODO(suderman): These string values should be declared the TOSA dialect.
    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // Identical types: the resize is a no-op.
    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    ArrayRef<int64_t> scale = op.getScale();

    // Collapse the unit width and height away.
    // Groups: {N} and {H, W, C}, producing a rank-2 tensor.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting scaling the input value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            // Bilinear mode pre-multiplies by the non-zero scale numerators
            // -- presumably to match the accumulation the general bilinear
            // lowering would perform; verify against the TOSA spec.
            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    // Re-expand the unit H/W dims to produce the NHWC result.
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
1430 
// TOSA resize with width or height of 1 may be broadcasted to a wider
// dimension. This is done by materializing a new tosa.resize without
// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  // NOTE(review): the listing dropped the inherited-constructor `using`
  // declaration that normally sits here -- confirm against the original file.

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Broadcasting only applies when a unit input dim maps to a wider
    // output dim; bail out otherwise.
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    // Collapse an unit result dims.
    // Unit dims are folded into the previous reassociation group; non-unit
    // dims start a new group.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape{batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // Input map skips the broadcast (unit) dims so each output element reads
    // the single collapsed source element along those dims.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          // Pure copy body; the indexing maps realize the broadcast.
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
1529 
1530 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1531 public:
1533 
1534  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1535  PatternRewriter &rewriter) const final {
1536  Location loc = op.getLoc();
1537  ImplicitLocOpBuilder b(loc, rewriter);
1538  auto input = op.getInput();
1539  auto inputTy = cast<ShapedType>(input.getType());
1540  auto resultTy = cast<ShapedType>(op.getType());
1541  auto resultETy = resultTy.getElementType();
1542 
1543  bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1544  auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1545 
1546  auto imageH = inputTy.getShape()[1];
1547  auto imageW = inputTy.getShape()[2];
1548 
1549  auto dynamicDimsOr =
1550  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1551  if (!dynamicDimsOr.has_value())
1552  return rewriter.notifyMatchFailure(
1553  op, "unable to get dynamic dimensions of tosa.resize");
1554 
1555  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1556  return rewriter.notifyMatchFailure(
1557  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1558 
1559  SmallVector<AffineMap, 2> affineMaps = {
1560  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1561  auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1562  *dynamicDimsOr);
1563  auto genericOp = b.create<linalg::GenericOp>(
1564  resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1565  getNParallelLoopsAttrs(resultTy.getRank()));
1566  Value resize = genericOp.getResult(0);
1567 
1568  {
1569  OpBuilder::InsertionGuard regionGuard(b);
1570  b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1571  TypeRange({resultETy}), loc);
1572  Value batch = b.create<linalg::IndexOp>(0);
1573  Value y = b.create<linalg::IndexOp>(1);
1574  Value x = b.create<linalg::IndexOp>(2);
1575  Value channel = b.create<linalg::IndexOp>(3);
1576 
1577  Value zeroI32 =
1578  b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1579  Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1580  Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1581  Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1582 
1583  Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1584  Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1585 
1586  ArrayRef<int64_t> offset = op.getOffset();
1587  ArrayRef<int64_t> border = op.getBorder();
1588  ArrayRef<int64_t> scale = op.getScale();
1589 
1590  Value yScaleN, yScaleD, xScaleN, xScaleD;
1591  yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1592  yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1593  xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1594  xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1595 
1596  Value yOffset, xOffset, yBorder, xBorder;
1597  yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1598  xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1599  yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1600  xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1601 
1602  // Compute the ix and dx values for both the X and Y dimensions.
1603  auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1604  Value scaleN, Value scaleD, Value offset,
1605  int size, ImplicitLocOpBuilder &b) {
1606  if (size == 1) {
1607  index = zeroI32;
1608  delta = zeroFp;
1609  return;
1610  }
1611  // x = x * scale_d + offset;
1612  // ix = floor(x / scale_n)
1613  Value val = b.create<arith::MulIOp>(in, scaleD);
1614  val = b.create<arith::AddIOp>(val, offset);
1615  index = b.create<arith::FloorDivSIOp>(val, scaleN);
1616 
1617  // rx = x % scale_n
1618  // dx = rx / scale_n
1619  Value r = b.create<arith::RemSIOp>(val, scaleN);
1620  Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1621  Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1622  delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1623  };
1624 
1625  // Compute the ix and dx values for the X and Y dimensions - int case.
1626  auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1627  Value scaleN, Value scaleD, Value offset,
1628  int size, ImplicitLocOpBuilder &b) {
1629  if (size == 1) {
1630  index = zeroI32;
1631  delta = zeroI32;
1632  return;
1633  }
1634  // x = x * scale_d + offset;
1635  // ix = floor(x / scale_n)
1636  // dx = x - ix * scale_n;
1637  Value val = b.create<arith::MulIOp>(in, scaleD);
1638  val = b.create<arith::AddIOp>(val, offset);
1639  index = b.create<arith::DivSIOp>(val, scaleN);
1640  delta = b.create<arith::MulIOp>(index, scaleN);
1641  delta = b.create<arith::SubIOp>(val, delta);
1642  };
1643 
1644  Value ix, iy, dx, dy;
1645  if (floatingPointMode) {
1646  getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1647  getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1648  } else {
1649  getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1650  getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1651  }
1652 
1653  if (op.getMode() == "NEAREST_NEIGHBOR") {
1654  auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1655 
1656  auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1657  Value max, int size,
1658  ImplicitLocOpBuilder &b) -> Value {
1659  if (size == 1) {
1660  return b.create<arith::ConstantIndexOp>(0);
1661  }
1662 
1663  Value pred;
1664  if (floatingPointMode) {
1665  auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1666  pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1667  } else {
1668  Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1669  pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1670  dvalDouble, scale);
1671  }
1672 
1673  auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1674  val = b.create<arith::AddIOp>(val, offset);
1675  val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
1676  return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1677  };
1678 
1679  iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1680  ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1681 
1682  Value result = b.create<tensor::ExtractOp>(
1683  input, ValueRange{batch, iy, ix, channel});
1684 
1685  b.create<linalg::YieldOp>(result);
1686  } else {
1687  // The mode here must be BILINEAR.
1688  assert(op.getMode() == "BILINEAR");
1689 
1690  auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1691 
1692  auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1694  val0 = in;
1695  val1 = b.create<arith::AddIOp>(val0, oneVal);
1696  val0 =
1697  clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1698  val1 =
1699  clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
1700  val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1701  val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1702  };
1703 
1704  // Linalg equivalent to the section below:
1705  // int16_t iy0 = apply_max(iy, 0);
1706  // int16_t iy1 = apply_min(iy + 1, IH - 1);
1707  // int16_t ix0 = apply_max(ix, 0);
1708  // int16_t ix1 = apply_min(ix + 1, IW - 1);
1709  Value x0, x1, y0, y1;
1710  getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1711  getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1712 
1713  Value y0x0 = b.create<tensor::ExtractOp>(
1714  input, ValueRange{batch, y0, x0, channel});
1715  Value y0x1 = b.create<tensor::ExtractOp>(
1716  input, ValueRange{batch, y0, x1, channel});
1717  Value y1x0 = b.create<tensor::ExtractOp>(
1718  input, ValueRange{batch, y1, x0, channel});
1719  Value y1x1 = b.create<tensor::ExtractOp>(
1720  input, ValueRange{batch, y1, x1, channel});
1721 
1722  if (floatingPointMode) {
1723  auto oneVal =
1724  b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1725  auto interpolate = [&](Value val0, Value val1, Value delta,
1726  int inputSize,
1727  ImplicitLocOpBuilder &b) -> Value {
1728  if (inputSize == 1)
1729  return val0;
1730  Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1731  Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1732  Value mul1 = b.create<arith::MulFOp>(val1, delta);
1733  return b.create<arith::AddFOp>(mul0, mul1);
1734  };
1735 
1736  // Linalg equivalent to the section below:
1737  // topAcc = v00 * (unit_x - dx);
1738  // topAcc += v01 * dx;
1739  Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1740 
1741  // Linalg equivalent to the section below:
1742  // bottomAcc = v10 * (unit_x - dx);
1743  // bottomAcc += v11 * dx;
1744  Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1745 
1746  // Linalg equivalent to the section below:
1747  // result = topAcc * (unit_y - dy) + bottomAcc * dy
1748  Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1749  b.create<linalg::YieldOp>(result);
1750  } else {
1751  // Perform in quantized space.
1752  y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1753  y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1754  y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1755  y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1756 
1757  const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1758  if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1759  dx = b.create<arith::ExtSIOp>(resultETy, dx);
1760  dy = b.create<arith::ExtSIOp>(resultETy, dy);
1761  }
1762 
1763  Value yScaleNExt = yScaleN;
1764  Value xScaleNExt = xScaleN;
1765 
1766  const int64_t scaleBitwidth =
1767  xScaleN.getType().getIntOrFloatBitWidth();
1768  if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1769  yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1770  xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1771  }
1772 
1773  auto interpolate = [](Value val0, Value val1, Value weight1,
1774  Value scale, int inputSize,
1775  ImplicitLocOpBuilder &b) -> Value {
1776  if (inputSize == 1)
1777  return b.create<arith::MulIOp>(val0, scale);
1778  Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1779  Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1780  Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1781  return b.create<arith::AddIOp>(mul0, mul1);
1782  };
1783 
1784  Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1785  Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1786  Value result =
1787  interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1788  b.create<linalg::YieldOp>(result);
1789  }
1790  }
1791  }
1792 
1793  rewriter.replaceOp(op, resize);
1794  return success();
1795  }
1796 };
1797 
// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross-device computation) should be
// handled before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  // NOTE(review): the inherited-constructor using-declaration appears to have
  // been elided from this listing — verify against the upstream source.

  // Erases the identity op by forwarding its operands directly to all uses of
  // its results; no new IR is created.
  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};
1812 
// Lowers a TOSA reduction op (the concrete op is the template parameter) by
// delegating to the shared reduction lowering helper defined earlier in this
// file.
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  // NOTE(review): the inherited-constructor using-declaration appears to have
  // been elided from this listing — verify against the upstream source.

  // Forwards the op and its axis attribute to reduceMatchAndRewriteHelper,
  // which builds the corresponding linalg reduction.
  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
1823 
// Lowers tosa.reverse to a linalg.generic that, for each output element,
// extracts the input element whose coordinate along the reversed axis is
// mirrored: index' = (size - 1) - index.
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  // NOTE(review): the inherited-constructor using-declaration appears to have
  // been elided from this listing — verify against the upstream source.

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    // Collect runtime sizes for each dynamic input dimension; these feed the
    // tensor.empty below.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // Runtime size of the reversed axis, needed to compute size - 1 - index.
    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    // Only the output is mapped (identity); the input is read via explicit
    // tensor.extract inside the body so the reversed index can be computed.
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            // Mirror the coordinate only along the reversed axis.
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
1881 
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcasted to the appropriate
// multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  // NOTE(review): the inherited-constructor using-declaration appears to have
  // been elided from this listing — verify against the upstream source.

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    ArrayRef<int64_t> multiples = op.getMultiples();

    // Broadcast the newly added dimensions to their appropriate multiple.
    // Build the interleaved shape [m0, d0, m1, d1, ...]; a multiple of -1
    // becomes a dynamic dimension.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    // Runtime sizes for every dimension that is dynamic in the broadcast
    // shape (either dynamic in the input or a dynamic multiple).
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions,
    // i.e. the odd positions (i * 2 + 1) of the interleaved shape.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    // The generic body simply yields the input value; the indexing maps
    // realize the broadcast across the multiple dimensions.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    // Collapse the interleaved broadcast shape back to the tiled result
    // shape.
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};
1947 
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  // NOTE(review): the inherited-constructor using-declaration appears to have
  // been elided from this listing — verify against the upstream source.

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Shadow result type carrying the running-max values (input element
    // type), discarded after the generic op.
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Dynamic sizes for all non-reduced dimensions (the axis dimension is
    // reduced away and does not appear in the result).
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations
    // along the rest.
    // NOTE(review): the declaration of `iteratorTypes` appears to have been
    // elided from this listing — verify against the upstream source.
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    // Source map covers all dims; both destinations (index and max) drop the
    // reduced axis.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Current coordinate along the reduced axis, cast to the result's
          // integer type.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            // Unsupported element type: record the failure; reported after
            // the generic op is built.
            didEncounterError = true;
            return;
          }

          // Keep the new value/index only when strictly greater than the
          // running max.
          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index result is kept; the max-value result is discarded.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
2079 
// Lowers tosa.gather to a linalg.generic: for each output element at
// (batch, n, channel) it extracts input[batch, indices[batch, n], channel].
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  // NOTE(review): the inherited-constructor using-declaration appears to have
  // been elided from this listing — verify against the upstream source.
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy =
        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
    auto resultTy = cast<ShapedType>(op.getType());

    if (!valuesTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    // Indices are read at (d0, d1); the output uses the identity map.
    // NOTE(review): the head of the first map's construction (an
    // `AffineMap::get(` call) appears to have been elided from this listing —
    // verify against the upstream source.
    SmallVector<AffineMap, 2> affineMaps = {
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          // batch coordinate, gathered row from the indices tensor, and
          // channel coordinate.
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  // Collects the runtime sizes of the result's dynamic dimensions: batch from
  // `values` dim 0, gathered length from `indices` dim 1, channels from
  // `values` dim 2. Static sizes fold away and contribute nothing.
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      // Only dynamic sizes materialize as Values; static ones are attributes.
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
2151 
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant,
// this simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  // NOTE(review): the inherited-constructor using-declaration appears to have
  // been elided from this listing — verify against the upstream source.

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    // Runtime sizes for dynamic result dimensions, taken from the input.
    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    // The generic is created (and the op replaced) first; its body region is
    // populated below depending on the element-type combination.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      // i8 -> i8 through an i8 table: a single direct lookup. The +128 offset
      // shifts the signed i8 value into the table's [0, 255] index range.
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      // i16 -> i32 through an i16 table: gather the two bracketing entries
      // and linearly interpolate using the low 7 bits as the fraction.
      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        // value = value + 32768
        // index = value >> 7;
        // fraction = value & 127 (0b1111111, the low 7 bits)
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        // base = (int32_t) table[index];
        // next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    // No supported element-type combination matched.
    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
2276 
// Lowers tosa.rfft2d to a linalg.generic computing a naive 2-D real-input
// DFT: the two outputs (real, imaginary) accumulate
//   sum over (iy, ix) of input[n, iy, ix] * cos(angle)  and
//   -input[n, iy, ix] * sin(angle),
// where angle = 2*pi * ((iy*oy mod H)/H + (ix*ox mod W)/W). The output W
// dimension is halved plus one, exploiting real-input symmetry.
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  // NOTE(review): the inherited-constructor using-declaration appears to have
  // been elided from this listing — verify against the upstream source.

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  // Returns (ofr / 2) + 1 as an OpFoldResult, folding when possible.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  // Computes the [N, H, W/2 + 1] output type; dynamic sizes are appended to
  // `dynamicSizes`.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  // Creates a tensor of `type` filled with zeros, for use as a reduction
  // accumulator.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  // Converts an index value to float via an unsigned integer of sufficient
  // width (i64 when the float has more than 32 bits of width, else i32).
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  // Emits linalg.index for loop dimension `index` and casts it to `type`.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  // Builds a list of affine dim expressions for the given dim positions.
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
    auto elementType =
        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation: parallel over
    // (batch, oy, ox), reduction over (iy, ix).
    // NOTE(review): the declaration head of `iteratorTypes` appears to have
    // been elided from this listing — verify against the upstream source.
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors: the input is read at
    // (batch, iy, ix); both outputs are written at (batch, oy, ox).
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2448 
2449 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2451 
2452  LogicalResult matchAndRewrite(FFT2dOp fft2d,
2453  PatternRewriter &rewriter) const override {
2454  if (!llvm::all_of(fft2d->getOperandTypes(),
2455  RFFT2dConverter::isRankedTensor) ||
2456  !llvm::all_of(fft2d->getResultTypes(),
2457  RFFT2dConverter::isRankedTensor)) {
2458  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2459  }
2460 
2461  Location loc = fft2d.getLoc();
2462  Value input_real = fft2d.getInputReal();
2463  Value input_imag = fft2d.getInputImag();
2464  BoolAttr inverse = fft2d.getInverseAttr();
2465 
2466  auto real_el_ty = cast<FloatType>(
2467  cast<ShapedType>(input_real.getType()).getElementType());
2468  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2469  cast<ShapedType>(input_imag.getType()).getElementType());
2470 
2471  assert(real_el_ty == imag_el_ty);
2472 
2473  // Compute the output type and set of dynamic sizes
2474  SmallVector<Value> dynamicSizes;
2475 
2476  // Get [N, H, W]
2477  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2478 
2479  SmallVector<int64_t, 3> staticSizes;
2480  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2481 
2482  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2483 
2484  // Iterator types for the linalg.generic implementation
2485  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2486  utils::IteratorType::parallel, utils::IteratorType::parallel,
2487  utils::IteratorType::parallel, utils::IteratorType::reduction,
2488  utils::IteratorType::reduction};
2489 
2490  // Inputs/outputs to the linalg.generic implementation
2491  SmallVector<Value> genericOpInputs = {input_real, input_imag};
2492  SmallVector<Value> genericOpOutputs = {
2493  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2494  dynamicSizes),
2495  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2496  dynamicSizes)};
2497 
2498  // Indexing maps for input and output tensors
2499  auto indexingMaps = AffineMap::inferFromExprList(
2500  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2501  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2502  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2503  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2504  rewriter.getContext());
2505 
2506  // Width and height dimensions of the original input.
2507  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2508  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2509 
2510  // Constants and dimension sizes
2511  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2512  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2513  Value constH =
2514  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2515  Value constW =
2516  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2517 
2518  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2519  Value valReal = args[0];
2520  Value valImag = args[1];
2521  Value sumReal = args[2];
2522  Value sumImag = args[3];
2523 
2524  // Indices for angle computation
2525  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2526  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2527  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2528  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2529 
2530  // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2531  // ox) % W ) / W);
2532  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2533  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2534 
2535  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2536  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2537 
2538  auto iyRemFloat =
2539  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2540  auto ixRemFloat =
2541  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2542 
2543  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2544  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2545 
2546  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2547  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2548 
2549  if (inverse.getValue()) {
2550  angle = builder.create<arith::MulFOp>(
2551  loc, angle,
2552  rewriter.create<arith::ConstantOp>(
2553  loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2554  }
2555 
2556  // realComponent = val_real * cos(a) + val_imag * sin(a);
2557  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2558  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2559  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2560 
2561  auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2562  auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2563  auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2564 
2565  auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2566  auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2567 
2568  auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2569 
2570  // outReal = sumReal + realComponent
2571  // outImag = sumImag - imagComponent
2572  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2573  auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2574 
2575  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2576  };
2577 
2578  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2579  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2580  indexingMaps, iteratorTypes, buildBody);
2581 
2582  return success();
2583  }
2584 };
2585 
2586 } // namespace
2587 
2589  const TypeConverter &converter, RewritePatternSet *patterns) {
2590 
2591  // We have multiple resize coverters to handle degenerate cases.
2592  patterns->add<GenericResizeConverter>(patterns->getContext(),
2593  /*benefit=*/100);
2594  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2595  /*benefit=*/200);
2596  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2597  /*benefit=*/300);
2598 
2599  patterns->add<
2600  // clang-format off
2601  PointwiseConverter<tosa::AddOp>,
2602  PointwiseConverter<tosa::SubOp>,
2603  PointwiseConverter<tosa::MulOp>,
2604  PointwiseConverter<tosa::IntDivOp>,
2605  PointwiseConverter<tosa::NegateOp>,
2606  PointwiseConverter<tosa::PowOp>,
2607  PointwiseConverter<tosa::ReciprocalOp>,
2608  PointwiseConverter<tosa::RsqrtOp>,
2609  PointwiseConverter<tosa::LogOp>,
2610  PointwiseConverter<tosa::ExpOp>,
2611  PointwiseConverter<tosa::AbsOp>,
2612  PointwiseConverter<tosa::SinOp>,
2613  PointwiseConverter<tosa::CosOp>,
2614  PointwiseConverter<tosa::TanhOp>,
2615  PointwiseConverter<tosa::ErfOp>,
2616  PointwiseConverter<tosa::BitwiseAndOp>,
2617  PointwiseConverter<tosa::BitwiseOrOp>,
2618  PointwiseConverter<tosa::BitwiseNotOp>,
2619  PointwiseConverter<tosa::BitwiseXorOp>,
2620  PointwiseConverter<tosa::LogicalAndOp>,
2621  PointwiseConverter<tosa::LogicalNotOp>,
2622  PointwiseConverter<tosa::LogicalOrOp>,
2623  PointwiseConverter<tosa::LogicalXorOp>,
2624  PointwiseConverter<tosa::CastOp>,
2625  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2626  PointwiseConverter<tosa::LogicalRightShiftOp>,
2627  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2628  PointwiseConverter<tosa::ClzOp>,
2629  PointwiseConverter<tosa::SelectOp>,
2630  PointwiseConverter<tosa::GreaterOp>,
2631  PointwiseConverter<tosa::GreaterEqualOp>,
2632  PointwiseConverter<tosa::EqualOp>,
2633  PointwiseConverter<tosa::MaximumOp>,
2634  PointwiseConverter<tosa::MinimumOp>,
2635  PointwiseConverter<tosa::CeilOp>,
2636  PointwiseConverter<tosa::FloorOp>,
2637  PointwiseConverter<tosa::ClampOp>,
2638  PointwiseConverter<tosa::SigmoidOp>
2639  >(converter, patterns->getContext());
2640 
2641  patterns->add<
2642  IdentityNConverter<tosa::IdentityOp>,
2643  ReduceConverter<tosa::ReduceAllOp>,
2644  ReduceConverter<tosa::ReduceAnyOp>,
2645  ReduceConverter<tosa::ReduceMinOp>,
2646  ReduceConverter<tosa::ReduceMaxOp>,
2647  ReduceConverter<tosa::ReduceSumOp>,
2648  ReduceConverter<tosa::ReduceProdOp>,
2649  ArgMaxConverter,
2650  GatherConverter,
2651  RescaleConverter,
2652  ReverseConverter,
2653  RFFT2dConverter,
2654  FFT2dConverter,
2655  TableConverter,
2656  TileConverter>(patterns->getContext());
2657  // clang-format on
2658 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor, int64_t rank)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter)
static SmallVector< Value > expandInputRanks(PatternRewriter &rewriter, Location loc, ValueRange operands, int64_t rank)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static bool operandsAndResultsRanked(Operation *operation)
const float * table
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs and no symbols.
Definition: AffineMap.cpp:312
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:412
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:95
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:99
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:56
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:66
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362