1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 
34 #include <numeric>
35 
36 using namespace mlir;
37 using namespace mlir::tosa;
38 
39 template <typename T>
40 static arith::ConstantOp
41 createConstFromIntAttribute(Operation *op, const std::string &attrName,
42  Type requiredAttrType, OpBuilder &rewriter) {
43  auto castedN = static_cast<T>(
44  cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
45  return rewriter.create<arith::ConstantOp>(
46  op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
47 }
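// This helper materializes one of 'op's integer attributes as an arith.constant
// of the requested type; the RescaleConverter further below uses it to build the
// "input_zp" and "output_zp" constants.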
48 
49 static Value createLinalgBodyCalculationForElementwiseOp(
50  Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51  ConversionPatternRewriter &rewriter) {
52  Location loc = op->getLoc();
53  auto elementTy =
54  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
55 
56  // tosa::AbsOp
57  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
58  return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
59 
60  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
61  auto zero = rewriter.create<arith::ConstantOp>(
62  loc, rewriter.getZeroAttr(elementTy));
63  auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
64  return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
65  }
66 
67  // tosa::AddOp
68  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
69  return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
70 
71  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
72  return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
73 
74  // tosa::SubOp
75  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
76  return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
77 
78  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
79  return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
80 
81  // tosa::MulOp
82  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
83  if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
84  (void)rewriter.notifyMatchFailure(op,
85  "Cannot have shift value for float");
86  return nullptr;
87  }
88  return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
89  }
90 
91  // tosa::IntDivOp
92  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
93  return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
94 
95  // tosa::ReciprocalOp
96  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
97  auto one =
98  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
99  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
100  }
101 
102  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
103  Value a = args[0];
104  Value b = args[1];
105  auto shift =
106  cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
107  if (shift > 0) {
108  auto shiftConst =
109  rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
110  if (!a.getType().isInteger(32))
111  a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
112 
113  if (!b.getType().isInteger(32))
114  b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
115 
116  auto result = rewriter.create<tosa::ApplyScaleOp>(
117  loc, rewriter.getI32Type(), a, b, shiftConst,
118  rewriter.getBoolAttr(false));
119 
120  if (elementTy.isInteger(32))
121  return result;
122 
123  return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
124  }
125 
126  int aWidth = a.getType().getIntOrFloatBitWidth();
127  int bWidth = b.getType().getIntOrFloatBitWidth();
128  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
129 
130  if (aWidth < cWidth)
131  a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
132  if (bWidth < cWidth)
133  b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
134 
135  return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
136  }
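// A rough worked example for the shifted integer multiply above: tosa.apply_scale
// computes approximately (a * b) >> shift with rounding, so a = 3, b = 5, shift = 1
// yields round(15 / 2) = 8 in i32 before being truncated back to the element type.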
137 
138  // tosa::NegateOp
139  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
140  return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
141 
142  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
143  !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
144  auto constant =
145  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
146  return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
147  }
148 
149  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
150  cast<tosa::NegateOp>(op).getQuantizationInfo()) {
151  auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
152  int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
153  int64_t inZp = quantizationInfo.value().getInputZp();
154  int64_t outZp = quantizationInfo.value().getOutputZp();
155 
156  // Compute the maximum value that can occur in the intermediate buffer.
157  int64_t zpAdd = inZp + outZp;
158  int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
159  std::abs(zpAdd) + 1;
160 
161  // Convert that maximum value into the maximum bitwidth needed to represent
162  // it. We assume 48-bit numbers may be supported further in the pipeline.
163  int intermediateBitWidth = 64;
164  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
165  intermediateBitWidth = 16;
166  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
167  intermediateBitWidth = 32;
168  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
169  intermediateBitWidth = 48;
170  }
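// For example, for an i8 input with inZp = -128 and outZp = 127: zpAdd = -1 and
// maxValue = 127 + |-1| + 1 = 129, which fits in 16 bits, so the subtraction below
// is performed in i16 before clamping back to the i8 range.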
171 
172  Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
173  Value zpAddValue = rewriter.create<arith::ConstantOp>(
174  loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
175 
176  // The negation can be applied by doing:
177  // outputValue = inZp + outZp - inputValue
178  auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
179  auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
180 
181  // Clamp to the negation range.
182  Value min = rewriter.create<arith::ConstantIntOp>(
183  loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
184  intermediateType);
185  Value max = rewriter.create<arith::ConstantIntOp>(
186  loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
187  intermediateType);
188  auto clamp =
189  clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
190 
191  // Truncate to the final value.
192  return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
193  }
194 
195  // tosa::BitwiseAndOp
196  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
197  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
198 
199  // tosa::BitwiseOrOp
200  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
201  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
202 
203  // tosa::BitwiseNotOp
204  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
205  auto allOnesAttr = rewriter.getIntegerAttr(
206  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
207  auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
208  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
209  }
210 
211  // tosa::BitwiseXOrOp
212  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
213  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
214 
215  // tosa::LogicalLeftShiftOp
216  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
217  return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
218 
219  // tosa::LogicalRightShiftOp
220  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
221  return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
222 
223  // tosa::ArithmeticRightShiftOp
224  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
225  auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
226  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
227  if (!round) {
228  return result;
229  }
230 
231  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
232  auto one =
233  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
234  auto zero =
235  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
236  auto i1one =
237  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
238 
239  // Checking that input2 != 0
240  auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
241  loc, arith::CmpIPredicate::sgt, args[1], zero);
242 
243  // Checking for the last bit of input1 to be 1
244  auto subtract =
245  rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
246  auto shifted =
247  rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
248  ->getResults();
249  auto truncated =
250  rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
251  auto isInputOdd =
252  rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
253 
254  auto shouldRound = rewriter.create<arith::AndIOp>(
255  loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
256  auto extended =
257  rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
258  return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
259  }
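// Rounding example: for args[0] = -7 and args[1] = 1, arith.shrsi produces -4;
// the last bit shifted out of -7 is 1 and the shift amount is positive, so 1 is
// added to give -3.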
260 
261  // tosa::ClzOp
262  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
263  return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
264  }
265 
266  // tosa::LogicalAnd
267  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
268  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
269 
270  // tosa::LogicalNot
271  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
272  auto one = rewriter.create<arith::ConstantOp>(
273  loc, rewriter.getIntegerAttr(elementTy, 1));
274  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
275  }
276 
277  // tosa::LogicalOr
278  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
279  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
280 
281  // tosa::LogicalXor
282  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
283  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
284 
285  // tosa::PowOp
286  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
287  return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
288 
289  // tosa::RsqrtOp
290  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
291  return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
292 
293  // tosa::LogOp
294  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
295  return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
296 
297  // tosa::ExpOp
298  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
299  return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
300 
301  // tosa::SinOp
302  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
303  return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
304 
305  // tosa::CosOp
306  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
307  return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
308 
309  // tosa::TanhOp
310  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
311  return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
312 
313  // tosa::ErfOp
314  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
315  return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
316 
317  // tosa::GreaterOp
318  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
319  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
320  args[0], args[1]);
321 
322  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
323  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
324  args[0], args[1]);
325 
326  // tosa::GreaterEqualOp
327  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
328  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
329  args[0], args[1]);
330 
331  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
332  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
333  args[0], args[1]);
334 
335  // tosa::EqualOp
336  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
337  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
338  args[0], args[1]);
339 
340  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
341  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
342  args[0], args[1]);
343 
344  // tosa::SelectOp
345  if (isa<tosa::SelectOp>(op)) {
346  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
347  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
348  return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
349  }
350 
351  // tosa::MaximumOp
352  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
353  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
354  }
355 
356  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
357  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
358  }
359 
360  // tosa::MinimumOp
361  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
362  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
363  }
364 
365  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
366  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
367  }
368 
369  // tosa::CeilOp
370  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
371  return rewriter.create<math::CeilOp>(loc, resultTypes, args);
372 
373  // tosa::FloorOp
374  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
375  return rewriter.create<math::FloorOp>(loc, resultTypes, args);
376 
377  // tosa::ClampOp
378  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
379  bool losesInfo = false;
380  APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
381  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
382  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
383  APFloat::rmNearestTiesToEven, &losesInfo);
384  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
385  APFloat::rmNearestTiesToEven, &losesInfo);
386  auto min = rewriter.create<arith::ConstantOp>(
387  loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
388  auto max = rewriter.create<arith::ConstantOp>(
389  loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
390  return clampFloatHelper(loc, args[0], min, max, rewriter);
391  }
392 
393  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
394  auto intTy = cast<IntegerType>(elementTy);
395  int64_t min =
396  cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
397  int64_t max =
398  cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
399 
400  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
401  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
402  if (intTy.isUnsignedInteger()) {
403  minRepresentable = 0;
404  if (intTy.getIntOrFloatBitWidth() <= 63) {
405  maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
406  .getZExtValue();
407  }
408  } else if (intTy.getIntOrFloatBitWidth() <= 64) {
409  // Ensure that min & max fit into signed n-bit constants.
410  minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
411  .getSExtValue();
412  maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
413  .getSExtValue();
414  }
415  // Ensure that the bounds are representable as n-bit signed/unsigned integers.
416  min = std::max(min, minRepresentable);
417  max = std::max(max, minRepresentable);
418  min = std::min(min, maxRepresentable);
419  max = std::min(max, maxRepresentable);
420 
421  auto minVal = rewriter.create<arith::ConstantIntOp>(
422  loc, min, intTy.getIntOrFloatBitWidth());
423  auto maxVal = rewriter.create<arith::ConstantIntOp>(
424  loc, max, intTy.getIntOrFloatBitWidth());
425  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
426  intTy.isUnsignedInteger());
427  }
428 
429  // tosa::SigmoidOp
430  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
431  auto one =
432  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
433  auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
434  auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
435  auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
436  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
437  }
438 
439  // tosa::CastOp
440  if (isa<tosa::CastOp>(op)) {
441  Type srcTy = elementTy;
442  Type dstTy = resultTypes.front();
443  bool bitExtend =
444  srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
445 
446  if (srcTy == dstTy)
447  return args.front();
448 
449  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
450  return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
451  std::nullopt);
452 
453  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
454  return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
455  std::nullopt);
456 
457  // 1-bit integers need to be treated as signless.
458  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
459  return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
460  std::nullopt);
461 
462  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
463  return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
464  std::nullopt);
465 
466  // Unsigned integers need an unrealized cast so that they can be passed
467  // to UIToFP.
468  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
469  auto unrealizedCast =
470  rewriter
471  .create<UnrealizedConversionCastOp>(
472  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
473  args[0])
474  .getResult(0);
475  return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
476  unrealizedCast);
477  }
478 
479  // All other si-to-fp conversions should be handled by SIToFP.
480  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
481  return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
482  std::nullopt);
483 
484  // When casting to boolean, floats only need to be checked as not-equal to zero.
485  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
486  Value zero = rewriter.create<arith::ConstantOp>(
487  loc, rewriter.getFloatAttr(srcTy, 0.0));
488  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
489  args.front(), zero);
490  }
491 
492  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
493  auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
494 
495  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
496  // Check whether neither int min nor int max can be represented in the
497  // input floating-point type due to too short exponent range.
498  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
499  APFloat::semanticsMaxExponent(fltSemantics)) {
500  // Use cmp + select to replace infinities by int min / int max. Other
501  // integral values can be represented in the integer space.
502  auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
503  auto posInf = rewriter.create<arith::ConstantOp>(
504  loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
505  APFloat::getInf(fltSemantics)));
506  auto negInf = rewriter.create<arith::ConstantOp>(
507  loc, rewriter.getFloatAttr(
508  getElementTypeOrSelf(srcTy),
509  APFloat::getInf(fltSemantics, /*Negative=*/true)));
510  auto overflow = rewriter.create<arith::CmpFOp>(
511  loc, arith::CmpFPredicate::UEQ, rounded, posInf);
512  auto underflow = rewriter.create<arith::CmpFOp>(
513  loc, arith::CmpFPredicate::UEQ, rounded, negInf);
514  auto intMin = rewriter.create<arith::ConstantOp>(
515  loc, rewriter.getIntegerAttr(
516  getElementTypeOrSelf(dstTy),
517  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
518  auto intMax = rewriter.create<arith::ConstantOp>(
519  loc, rewriter.getIntegerAttr(
520  getElementTypeOrSelf(dstTy),
521  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
522  auto maxClamped =
523  rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
524  return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
525  maxClamped);
526  }
527 
528  auto intMinFP = rewriter.create<arith::ConstantOp>(
529  loc, rewriter.getFloatAttr(
530  getElementTypeOrSelf(srcTy),
531  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
532  .getSExtValue()));
533 
534  // Check whether the mantissa has enough bits to represent int max.
535  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
536  dstTy.getIntOrFloatBitWidth() - 1) {
537  // Int min can also be represented since it is a power of two and thus
538  // consists of a single leading bit. Therefore we can clamp the input
539  // in the floating-point domain.
540 
541  auto intMaxFP = rewriter.create<arith::ConstantOp>(
542  loc, rewriter.getFloatAttr(
543  getElementTypeOrSelf(srcTy),
544  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
545  .getSExtValue()));
546 
547  Value clamped =
548  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
549  return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
550  }
551 
552  // Due to the earlier check we know the exponent range is big enough to represent
553  // int min. We can therefore rely on int max + 1 being representable as
554  // well because it's just int min with a positive sign. So clamp the min
555  // value and compare against that to select the max int value if needed.
556  auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
557  loc, rewriter.getFloatAttr(
558  getElementTypeOrSelf(srcTy),
559  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
560  .getSExtValue() +
561  1));
562 
563  auto intMax = rewriter.create<arith::ConstantOp>(
564  loc, rewriter.getIntegerAttr(
565  getElementTypeOrSelf(dstTy),
566  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
567  auto minClampedFP =
568  rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
569  auto minClamped =
570  rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
571  auto overflow = rewriter.create<arith::CmpFOp>(
572  loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
573  return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
574  minClamped);
575  }
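// Branch selection examples (assuming standard IEEE semantics): f16 -> i32 takes
// the infinity cmp/select path because the f16 exponent range cannot reach 2^31;
// f32 -> i8 clamps in the float domain because the 24-bit mantissa covers the i8
// bounds exactly; f32 -> i32 takes the max+1 comparison path because 2^31 - 1 is
// not exactly representable in f32.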
576 
577  // When casting to boolean, integers only need to be checked as not-equal to
578  // zero.
579  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
580  Value zero = rewriter.create<arith::ConstantIntOp>(
581  loc, 0, srcTy.getIntOrFloatBitWidth());
582  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
583  args.front(), zero);
584  }
585 
586  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
587  return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
588  std::nullopt);
589 
590  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
591  return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
592  }
593  }
594 
595  (void)rewriter.notifyMatchFailure(
596  op, "unhandled op for linalg body calculation for elementwise op");
597  return nullptr;
598 }
599 
600 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
601  int64_t rank) {
602  // No need to expand if we are already at the desired rank
603  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
604  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
605  int64_t numExtraDims = rank - shapedType.getRank();
606  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
607  if (!numExtraDims)
608  return tensor;
609 
610  // Compute reassociation indices
611  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
612  shapedType.getRank());
613  int64_t index = 0;
614  for (index = 0; index <= numExtraDims; index++)
615  reassociationIndices[0].push_back(index);
616  for (size_t position = 1; position < reassociationIndices.size(); position++)
617  reassociationIndices[position].push_back(index++);
618 
619  // Compute result type
620  SmallVector<int64_t> resultShape;
621  for (index = 0; index < numExtraDims; index++)
622  resultShape.push_back(1);
623  for (auto size : shapedType.getShape())
624  resultShape.push_back(size);
625  auto resultType =
626  RankedTensorType::get(resultShape, shapedType.getElementType());
627 
628  // Emit 'tensor.expand_shape' op
629  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
630  reassociationIndices);
631 }
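// For example, expanding a tensor<3x4xf32> to rank 4 yields reassociation indices
// [[0, 1, 2], [3]] and the result type tensor<1x1x3x4xf32>.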
632 
633 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
634  Location loc, ValueRange operands,
635  int64_t rank) {
636  return llvm::map_to_vector(operands, [&](Value operand) {
637  return expandRank(rewriter, loc, operand, rank);
638  });
639 }
640 
641 using IndexPool = DenseMap<int64_t, Value>;
642 
643 // Emit an 'arith.constant' op for the given index if it has not been created
644 // yet, or return an existing constant. This will prevent an excessive creation
645 // of redundant constants, easing readability of emitted code for unit tests.
646 static Value createIndex(PatternRewriter &rewriter, Location loc,
647  IndexPool &indexPool, int64_t index) {
648  auto [it, inserted] = indexPool.try_emplace(index);
649  if (inserted)
650  it->second =
651  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
652  return it->second;
653 }
654 
655 static Value getTensorDim(PatternRewriter &rewriter, Location loc,
656  IndexPool &indexPool, Value tensor, int64_t index) {
657  auto indexValue = createIndex(rewriter, loc, indexPool, index);
658  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
659 }
660 
661 static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
662  IndexPool &indexPool, Value tensor,
663  int64_t index) {
664  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
665  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
666  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
667  if (shapedType.isDynamicDim(index))
668  return getTensorDim(rewriter, loc, indexPool, tensor, index);
669  return rewriter.getIndexAttr(shapedType.getDimSize(index));
670 }
671 
672 static bool operandsAndResultsRanked(Operation *operation) {
673  auto isRanked = [](Value value) {
674  return isa<RankedTensorType>(value.getType());
675  };
676  return llvm::all_of(operation->getOperands(), isRanked) &&
677  llvm::all_of(operation->getResults(), isRanked);
678 }
679 
680 // Compute the runtime dimension size for dimension 'dim' of the output by
681 // inspecting input 'operands', all of which are expected to have the same rank.
682 // This function returns a pair {targetSize, masterOperand}.
683 //
684 // The runtime size of the output dimension is returned either as a statically
685 // computed attribute or as a runtime SSA value.
686 //
687 // If the target size was inferred directly from one dominating operand, that
688 // operand is returned in 'masterOperand'. If the target size is inferred from
689 // multiple operands, 'masterOperand' is set to nullptr.
690 static std::pair<OpFoldResult, Value>
691 computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
692  ValueRange operands, int64_t dim) {
693  // If any input operand contains a static size greater than 1 for this
694  // dimension, that is the target size. An occurrence of an additional static
695  // dimension greater than 1 with a different value is undefined behavior.
696  for (auto operand : operands) {
697  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
698  if (!ShapedType::isDynamic(size) && size > 1)
699  return {rewriter.getIndexAttr(size), operand};
700  }
701 
702  // Filter operands with dynamic dimension
703  auto operandsWithDynamicDim =
704  llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
705  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
706  }));
707 
708  // If no operand has a dynamic dimension, it means all sizes were 1
709  if (operandsWithDynamicDim.empty())
710  return {rewriter.getIndexAttr(1), operands.front()};
711 
712  // Emit code that computes the runtime size for this dimension. If there is
713  // only one operand with a dynamic dimension, it is considered the master
714  // operand that determines the runtime size of the output dimension.
715  auto targetSize =
716  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
717  if (operandsWithDynamicDim.size() == 1)
718  return {targetSize, operandsWithDynamicDim[0]};
719 
720  // Calculate maximum size among all dynamic dimensions
721  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
722  auto nextSize =
723  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
724  targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
725  }
726  return {targetSize, nullptr};
727 }
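// For example, if the operands' sizes in this dimension are (1, ?, 5), the static 5
// is returned with its operand as the master; if they are (1, ?, ?), the result is
// the runtime maximum of the two dynamic sizes and the master operand is null.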
728 
729 // Compute the runtime output size for all dimensions. This function returns
730 // a pair {targetShape, masterOperands}.
731 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
732 computeTargetShape(PatternRewriter &rewriter, Location loc,
733  IndexPool &indexPool, ValueRange operands) {
734  assert(!operands.empty());
735  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
736  SmallVector<OpFoldResult> targetShape;
737  SmallVector<Value> masterOperands;
738  for (auto dim : llvm::seq<int64_t>(0, rank)) {
739  auto [targetSize, masterOperand] =
740  computeTargetSize(rewriter, loc, indexPool, operands, dim);
741  targetShape.push_back(targetSize);
742  masterOperands.push_back(masterOperand);
743  }
744  return {targetShape, masterOperands};
745 }
746 
747 static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
748  IndexPool &indexPool, Value operand,
749  int64_t dim, OpFoldResult targetSize,
750  Value masterOperand) {
751  // Nothing to do if this is a static dimension
752  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
753  if (!rankedTensorType.isDynamicDim(dim))
754  return operand;
755 
756  // If the target size for this dimension was directly inferred by only taking
757  // this operand into account, there is no need to broadcast. This is an
758  // optimization that will prevent redundant control flow, and constitutes the
759  // main motivation for tracking "master operands".
760  if (operand == masterOperand)
761  return operand;
762 
763  // Affine maps for 'linalg.generic' op
764  auto rank = rankedTensorType.getRank();
765  SmallVector<AffineExpr> affineExprs;
766  for (auto index : llvm::seq<int64_t>(0, rank)) {
767  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
768  : rewriter.getAffineDimExpr(index);
769  affineExprs.push_back(affineExpr);
770  }
771  auto broadcastAffineMap =
772  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
773  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
774  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
775 
776  // Check if broadcast is necessary
777  auto one = createIndex(rewriter, loc, indexPool, 1);
778  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
779  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
780  loc, arith::CmpIPredicate::eq, runtimeSize, one);
781 
782  // Emit 'then' region of 'scf.if'
783  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
784  // It is not safe to cache constants across regions.
785  // New constants could potentially violate dominance requirements.
786  IndexPool localPool;
787 
788  // Emit 'tensor.empty' op
789  SmallVector<OpFoldResult> outputTensorShape;
790  for (auto index : llvm::seq<int64_t>(0, rank)) {
791  auto size = index == dim ? targetSize
792  : getOrFoldTensorDim(rewriter, loc, localPool,
793  operand, index);
794  outputTensorShape.push_back(size);
795  }
796  Value outputTensor = opBuilder.create<tensor::EmptyOp>(
797  loc, outputTensorShape, rankedTensorType.getElementType());
798 
799  // Emit 'linalg.generic' op
800  auto resultTensor =
801  opBuilder
802  .create<linalg::GenericOp>(
803  loc, outputTensor.getType(), operand, outputTensor, affineMaps,
804  getNParallelLoopsAttrs(rank),
805  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
806  // Emit 'linalg.yield' op
807  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
808  })
809  .getResult(0);
810 
811  // Cast to original operand type if necessary
812  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
813  loc, operand.getType(), resultTensor);
814 
815  // Emit 'scf.yield' op
816  opBuilder.create<scf::YieldOp>(loc, castResultTensor);
817  };
818 
819  // Emit 'else' region of 'scf.if'
820  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
821  opBuilder.create<scf::YieldOp>(loc, operand);
822  };
823 
824  // Emit 'scf.if' op
825  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
826  emitThenRegion, emitElseRegion);
827  return ifOp.getResult(0);
828 }
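// The scf.if above broadcasts only when the runtime size of this dimension is 1;
// otherwise the else region yields the operand unchanged. Casting the broadcast
// result back to the original (dynamic) operand type keeps both branches
// type-compatible.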
829 
830 static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
831  IndexPool &indexPool, Value operand,
832  ArrayRef<OpFoldResult> targetShape,
833  ArrayRef<Value> masterOperands) {
834  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
835  assert((int64_t)targetShape.size() == rank);
836  assert((int64_t)masterOperands.size() == rank);
837  for (auto index : llvm::seq<int64_t>(0, rank))
838  operand =
839  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
840  targetShape[index], masterOperands[index]);
841  return operand;
842 }
843 
844 static SmallVector<Value>
845 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
846  IndexPool &indexPool, ValueRange operands,
847  ArrayRef<OpFoldResult> targetShape,
848  ArrayRef<Value> masterOperands) {
849  // No need to broadcast for unary operations
850  if (operands.size() == 1)
851  return operands;
852 
853  // Broadcast dynamic dimensions operand by operand
854  return llvm::map_to_vector(operands, [&](Value operand) {
855  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
856  targetShape, masterOperands);
857  });
858 }
859 
860 static LogicalResult
861 emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
862  Operation *operation, ValueRange operands,
863  ArrayRef<OpFoldResult> targetShape,
864  const TypeConverter &converter) {
865  // Generate output tensor
866  auto resultType = cast_or_null<RankedTensorType>(
867  converter.convertType(operation->getResultTypes().front()));
868  if (!resultType) {
869  return rewriter.notifyMatchFailure(operation, "failed to convert type");
870  }
871  Value outputTensor = rewriter.create<tensor::EmptyOp>(
872  loc, targetShape, resultType.getElementType());
873 
874  // Create affine maps. Input affine maps broadcast static dimensions of size
875  // 1. The output affine map is an identity map.
876  //
877  auto rank = resultType.getRank();
878  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
879  auto shape = cast<ShapedType>(operand.getType()).getShape();
880  SmallVector<AffineExpr> affineExprs;
881  for (auto it : llvm::enumerate(shape)) {
882  auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
883  : rewriter.getAffineDimExpr(it.index());
884  affineExprs.push_back(affineExpr);
885  }
886  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
887  });
888  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
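// For example, with rank 2 and operands of type tensor<1x?xf32> and tensor<4x?xf32>,
// the input maps are (d0, d1) -> (0, d1) and (d0, d1) -> (d0, d1), followed by the
// identity map for the output.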
889 
890  // Emit 'linalg.generic' op
891  bool encounteredError = false;
892  auto linalgOp = rewriter.create<linalg::GenericOp>(
893  loc, outputTensor.getType(), operands, outputTensor, affineMaps,
894  getNParallelLoopsAttrs(rank),
895  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
896  Value opResult = createLinalgBodyCalculationForElementwiseOp(
897  operation, blockArgs.take_front(operation->getNumOperands()),
898  {resultType.getElementType()}, rewriter);
899  if (!opResult) {
900  encounteredError = true;
901  return;
902  }
903  opBuilder.create<linalg::YieldOp>(loc, opResult);
904  });
905  if (encounteredError)
906  return rewriter.notifyMatchFailure(
907  operation, "unable to create linalg.generic body for elementwise op");
908 
909  // Cast 'linalg.generic' result into original result type if needed
910  auto castResult = rewriter.createOrFold<tensor::CastOp>(
911  loc, resultType, linalgOp->getResult(0));
912  rewriter.replaceOp(operation, castResult);
913  return success();
914 }
915 
916 static LogicalResult
917 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
918  ConversionPatternRewriter &rewriter,
919  const TypeConverter &converter) {
920 
921  // Collect op properties
922  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
923  assert(operation->getNumOperands() >= 1 &&
924  "elementwise op expects at least 1 operand");
925  if (!operandsAndResultsRanked(operation))
926  return rewriter.notifyMatchFailure(operation,
927  "Unranked tensors not supported");
928 
929  // Lower operation
930  IndexPool indexPool;
931  auto loc = operation->getLoc();
932  auto rank =
933  cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
934  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
935  auto [targetShape, masterOperands] =
936  computeTargetShape(rewriter, loc, indexPool, expandedOperands);
937  auto broadcastOperands = broadcastDynamicDimensions(
938  rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
939  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
940  targetShape, converter);
941 }
942 
943 // Returns the constant initial value for a given reduction operation. The
944 // attribute type varies depending on the element type required.
945 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
946  PatternRewriter &rewriter) {
947  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
948  return rewriter.getFloatAttr(elementTy, 0.0);
949 
950  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
951  return rewriter.getIntegerAttr(elementTy, 0);
952 
953  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
954  return rewriter.getFloatAttr(elementTy, 1.0);
955 
956  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
957  return rewriter.getIntegerAttr(elementTy, 1);
958 
959  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
960  return rewriter.getFloatAttr(
961  elementTy, APFloat::getLargest(
962  cast<FloatType>(elementTy).getFloatSemantics(), false));
963 
964  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
965  return rewriter.getIntegerAttr(
966  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
967 
968  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
969  return rewriter.getFloatAttr(
970  elementTy, APFloat::getLargest(
971  cast<FloatType>(elementTy).getFloatSemantics(), true));
972 
973  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
974  return rewriter.getIntegerAttr(
975  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
976 
977  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
978  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
979 
980  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
981  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
982 
983  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
984  return rewriter.getFloatAttr(
985  elementTy, APFloat::getLargest(
986  cast<FloatType>(elementTy).getFloatSemantics(), true));
987 
988  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
989  return rewriter.getIntegerAttr(
990  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
991 
992  return {};
993 }
994 
995 // Creates the body calculation for a reduction. The operations vary depending
996 // on the input type.
997 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
998  ValueRange args,
999  Type elementTy,
1000  PatternRewriter &rewriter) {
1001  Location loc = op->getLoc();
1002  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1003  return rewriter.create<arith::AddFOp>(loc, args);
1004  }
1005 
1006  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1007  return rewriter.create<arith::AddIOp>(loc, args);
1008  }
1009 
1010  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
1011  return rewriter.create<arith::MulFOp>(loc, args);
1012  }
1013 
1014  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
1015  return rewriter.create<arith::MulIOp>(loc, args);
1016  }
1017 
1018  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1019  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1020  }
1021 
1022  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1023  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1024  }
1025 
1026  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1027  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1028  }
1029 
1030  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1031  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1032  }
1033 
1034  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1035  return rewriter.create<arith::AndIOp>(loc, args);
1036 
1037  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1038  return rewriter.create<arith::OrIOp>(loc, args);
1039 
1040  return {};
1041 }
1042 
1043 // Performs the match and rewrite for reduction operations. This includes
1044 // declaring a correctly sized initial value, and the linalg.generic operation
1045 // that reduces across the specified axis.
1046 static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1047  PatternRewriter &rewriter) {
1048  auto loc = op->getLoc();
1049  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
1050  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
1051  auto elementTy = resultTy.getElementType();
1052  Value input = op->getOperand(0);
1053 
1054  SmallVector<int64_t> reduceShape;
1055  SmallVector<Value> dynDims;
1056  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1057  if (axis != i) {
1058  reduceShape.push_back(inputTy.getDimSize(i));
1059  if (inputTy.isDynamicDim(i))
1060  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1061  }
1062  }
1063 
1064  // First fill the output buffer with the init value.
1065  auto emptyTensor =
1066  rewriter
1067  .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1068  dynDims)
1069  .getResult();
1070 
1071  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1072  if (!fillValueAttr)
1073  return rewriter.notifyMatchFailure(
1074  op, "No initial value found for reduction operation");
1075 
1076  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1077  auto filledTensor = rewriter
1078  .create<linalg::FillOp>(loc, ValueRange{fillValue},
1079  ValueRange{emptyTensor})
1080  .result();
1081 
1082  bool didEncounterError = false;
1083  auto linalgOp = rewriter.create<linalg::ReduceOp>(
1084  loc, input, filledTensor, axis,
1085  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1086  auto result = createLinalgBodyCalculationForReduceOp(
1087  op, blockArgs, elementTy, rewriter);
1088  if (result)
1089  didEncounterError = true;
1090 
1091  nestedBuilder.create<linalg::YieldOp>(loc, result);
1092  });
1093 
1094  if (!didEncounterError)
1095  return rewriter.notifyMatchFailure(
1096  op, "unable to create linalg.generic body for reduce op");
1097 
1098  SmallVector<ReassociationExprs, 4> reassociationMap;
1099  uint64_t expandInputRank =
1100  cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
1101  reassociationMap.resize(expandInputRank);
1102 
1103  for (uint64_t i = 0; i < expandInputRank; i++) {
1104  int32_t dimToPush = i > axis ? i + 1 : i;
1105  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1106  }
1107 
1108  if (expandInputRank != 0) {
1109  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1110  reassociationMap[expandedDim].push_back(
1111  rewriter.getAffineDimExpr(expandedDim + 1));
1112  }
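// For example, reducing axis 1 of tensor<3x4x5> produces a rank-2 tensor<3x5>; the
// reassociation [[0], [1, 2]] expands it back to the tensor<3x1x5> result expected
// by TOSA.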
1113 
1114  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1115  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1116  // not have access to such information. This matters when handling dynamically
1117  // sized tensors.
1118  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1119  op, resultTy, linalgOp.getResults()[0], reassociationMap);
1120  return success();
1121 }
1122 
1123 namespace {
1124 
1125 template <typename SrcOp>
1126 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1127 public:
1128  using OpConversionPattern<SrcOp>::OpConversionPattern;
1129  using typename OpConversionPattern<SrcOp>::OpAdaptor;
1130 
1131  LogicalResult
1132  matchAndRewrite(SrcOp op, OpAdaptor operands,
1133  ConversionPatternRewriter &rewriter) const final {
1134  return elementwiseMatchAndRewriteHelper(
1135  op, operands.getOperands(), rewriter, *this->getTypeConverter());
1136  }
1137 };
1138 
1139 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1140 public:
1141  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1142 
1143  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1144  PatternRewriter &rewriter) const final {
1145  auto loc = op.getLoc();
1146  auto input = op.getInput();
1147  auto inputTy = cast<ShapedType>(op.getInput().getType());
1148  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1149  unsigned rank = inputTy.getRank();
1150 
1151  // This is an illegal configuration; terminate and log an error.
1152  if (op.getDoubleRound() && !op.getScale32())
1153  return rewriter.notifyMatchFailure(
1154  op, "tosa.rescale requires scale32 for double_round to be true");
1155 
1156  SmallVector<Value> dynDims;
1157  for (int i = 0; i < outputTy.getRank(); i++) {
1158  if (outputTy.isDynamicDim(i)) {
1159  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1160  }
1161  }
1162 
1163  // The shift and multiplier values.
1164  SmallVector<int32_t> multiplierValues(op.getMultiplier());
1165  SmallVector<int8_t> shiftValues(op.getShift());
1166 
1167  // If we shift by more than the bitwidth, the result is simply set to 0.
1168  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1169  if (shiftValues[i] > 63) {
1170  shiftValues[i] = 0;
1171  multiplierValues[i] = 0;
1172  }
1173  }
1174 
1175  // Double rounding only occurs if the shift is greater than 31; check whether
1176  // this is ever the case.
1177  bool doubleRound =
1178  op.getDoubleRound() &&
1179  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1180 
1181  SmallVector<AffineMap> indexingMaps = {
1182  rewriter.getMultiDimIdentityMap(rank)};
1183  SmallVector<Value, 4> genericInputs = {input};
1184 
1185  // If we are rescaling per-channel then we need to store the multiplier
1186  // values in a buffer.
1187  Value multiplierConstant;
1188  int64_t multiplierArg = 0;
1189  if (multiplierValues.size() == 1) {
1190  multiplierConstant = rewriter.create<arith::ConstantOp>(
1191  loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1192  } else {
1193  SmallVector<AffineExpr, 2> multiplierExprs{
1194  rewriter.getAffineDimExpr(rank - 1)};
1195  auto multiplierType =
1196  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1197  rewriter.getI32Type());
1198  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1199  loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1200 
1201  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1202  /*symbolCount=*/0, multiplierExprs,
1203  rewriter.getContext()));
1204 
1205  multiplierArg = indexingMaps.size() - 1;
1206  }
1207 
1208  // If we are rescaling per-channel then we need to store the shift
1209  // values in a buffer.
1210  Value shiftConstant;
1211  int64_t shiftArg = 0;
1212  if (shiftValues.size() == 1) {
1213  shiftConstant = rewriter.create<arith::ConstantOp>(
1214  loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1215  } else {
1216  SmallVector<AffineExpr, 2> shiftExprs = {
1217  rewriter.getAffineDimExpr(rank - 1)};
1218  auto shiftType =
1219  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1220  rewriter.getIntegerType(8));
1221  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1222  loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1223  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1224  /*symbolCount=*/0, shiftExprs,
1225  rewriter.getContext()));
1226  shiftArg = indexingMaps.size() - 1;
1227  }
1228 
1229  // Indexing maps for output values.
1230  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1231 
1232  // Construct the indexing maps needed for linalg.generic ops.
1233  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1234  loc, outputTy.getShape(), outputTy.getElementType(),
1235  ArrayRef<Value>({dynDims}));
1236 
1237  auto linalgOp = rewriter.create<linalg::GenericOp>(
1238  loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1239  getNParallelLoopsAttrs(rank),
1240  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1241  ValueRange blockArgs) {
1242  Value value = blockArgs[0];
1243  Type valueTy = value.getType();
1244 
1245  // For now we do all of our math in 64-bit. This is not optimal but
1246  // should be correct for now, consider computing correct bit depth
1247  // later.
1248  int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1249 
1250  auto inputZp = createConstFromIntAttribute<int32_t>(
1251  op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1252  nestedBuilder);
1253  auto outputZp = createConstFromIntAttribute<int32_t>(
1254  op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1255 
1256  Value multiplier = multiplierConstant ? multiplierConstant
1257  : blockArgs[multiplierArg];
1258  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1259 
1260  if (valueTy.getIntOrFloatBitWidth() < 32) {
1261  if (valueTy.isUnsignedInteger()) {
1262  value = nestedBuilder
1263  .create<UnrealizedConversionCastOp>(
1264  nestedLoc,
1265  nestedBuilder.getIntegerType(
1266  valueTy.getIntOrFloatBitWidth()),
1267  value)
1268  .getResult(0);
1269  value = nestedBuilder.create<arith::ExtUIOp>(
1270  nestedLoc, nestedBuilder.getI32Type(), value);
1271  } else {
1272  value = nestedBuilder.create<arith::ExtSIOp>(
1273  nestedLoc, nestedBuilder.getI32Type(), value);
1274  }
1275  }
1276 
1277  value =
1278  nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1279 
1280  value = nestedBuilder.create<tosa::ApplyScaleOp>(
1281  loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1282  nestedBuilder.getBoolAttr(doubleRound));
1283 
1284  // Move to the new zero-point.
1285  value =
1286  nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
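// At this point 'value' holds roughly apply_scale(input - input_zp, multiplier,
// shift) + output_zp; the remaining steps saturate it to the output type's range.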
1287 
1288  // Saturate to the output size.
1289  IntegerType outIntType =
1290  cast<IntegerType>(blockArgs.back().getType());
1291  unsigned outBitWidth = outIntType.getWidth();
1292 
1293  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1294  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1295 
1296  // Unsigned integers have a different output range.
1297  if (outIntType.isUnsignedInteger()) {
1298  intMin = 0;
1299  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1300  }
1301 
1302  auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1303  loc, nestedBuilder.getI32IntegerAttr(intMin));
1304  auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1305  loc, nestedBuilder.getI32IntegerAttr(intMax));
1306 
1307  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1308  nestedBuilder, /*isUnsigned=*/false);
1309 
1310  if (outIntType.getWidth() < 32) {
1311  value = nestedBuilder.create<arith::TruncIOp>(
1312  nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1313  value);
1314 
1315  if (outIntType.isUnsignedInteger()) {
1316  value = nestedBuilder
1317  .create<UnrealizedConversionCastOp>(nestedLoc,
1318  outIntType, value)
1319  .getResult(0);
1320  }
1321  }
1322 
1323  nestedBuilder.create<linalg::YieldOp>(loc, value);
1324  });
1325 
1326  rewriter.replaceOp(op, linalgOp->getResults());
1327  return success();
1328  }
1329 };
1330 
1331 // Handle the resize case where the input is a 1x1 image. This case can
1332 // entirely avoid the extract operations, which are much more difficult to
1333 // optimize away.
1334 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1335 public:
1336  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1337 
1338  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1339  PatternRewriter &rewriter) const final {
1340  Location loc = op.getLoc();
1341  ImplicitLocOpBuilder builder(loc, rewriter);
1342  auto input = op.getInput();
1343  auto inputTy = cast<RankedTensorType>(input.getType());
1344  auto resultTy = cast<RankedTensorType>(op.getType());
1345  const bool isBilinear = op.getMode() == "BILINEAR";
1346 
1347  auto inputH = inputTy.getDimSize(1);
1348  auto inputW = inputTy.getDimSize(2);
1349  auto outputH = resultTy.getDimSize(1);
1350  auto outputW = resultTy.getDimSize(2);
1351 
1352  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1353  return rewriter.notifyMatchFailure(
1354  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1355 
1356  // TODO(suderman): These string values should be declared in the TOSA dialect.
1357  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1358  return rewriter.notifyMatchFailure(
1359  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1360 
1361  if (inputTy == resultTy) {
1362  rewriter.replaceOp(op, input);
1363  return success();
1364  }
1365 
1366  ArrayRef<int64_t> scale = op.getScale();
1367 
1368  // Collapse the unit width and height away.
1369  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1370  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1371  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1372  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1373  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1374 
1375  auto collapseTy =
1376  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1377  inputTy.getElementType());
1378  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1379  reassociationMap);
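// For a 1x1 image this collapses NxHxWxC (with H = W = 1) down to NxC using the
// reassociation [[0], [1, 2, 3]]; the same map is reused at the end to expand the
// computed values back out to the NxHxWxC result shape.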
1380 
1381  // Get any dynamic shapes that appear in the input format.
1382  llvm::SmallVector<Value> outputDynSize;
1383  if (inputTy.isDynamicDim(0))
1384  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1385  if (inputTy.isDynamicDim(3))
1386  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1387 
1388  // Generate the elementwise operation for casting and scaling the input value.
1389  auto genericTy = collapseTy.clone(resultTy.getElementType());
1390  Value empty = builder.create<tensor::EmptyOp>(
1391  genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1392  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1393  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1394  utils::IteratorType::parallel);
1395 
1396  auto generic = builder.create<linalg::GenericOp>(
1397  genericTy, ValueRange{collapse}, ValueRange{empty},
1398  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1399  [=](OpBuilder &b, Location loc, ValueRange args) {
1400  Value value = args[0];
1401  // This is the quantized case.
1402  if (inputTy.getElementType() != resultTy.getElementType()) {
1403  value =
1404  b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1405 
1406  if (isBilinear && scale[0] != 0) {
1407  Value scaleY = b.create<arith::ConstantOp>(
1408  loc, b.getI32IntegerAttr(scale[0]));
1409  value = b.create<arith::MulIOp>(loc, value, scaleY);
1410  }
1411 
1412  if (isBilinear && scale[2] != 0) {
1413  Value scaleX = b.create<arith::ConstantOp>(
1414  loc, b.getI32IntegerAttr(scale[2]));
1415  value = b.create<arith::MulIOp>(loc, value, scaleX);
1416  }
1417  }
1418 
1419  b.create<linalg::YieldOp>(loc, value);
1420  });
1421 
1422  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1423  op, resultTy, generic.getResults()[0], reassociationMap);
1424  return success();
1425  }
1426 };
1427 
1428 // TOSA resize with width or height of 1 may be broadcasted to a wider
1429 // dimension. This is done by materializing a new tosa.resize without
1430 // the broadcasting behavior, and an explicit broadcast afterwards.
1431 class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1432 public:
1433  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1434 
1435  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1436  PatternRewriter &rewriter) const final {
1437  Location loc = op.getLoc();
1438  ImplicitLocOpBuilder builder(loc, rewriter);
1439  auto input = op.getInput();
1440  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1441  auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1442 
1443  if (!inputTy || !resultTy)
1444  return rewriter.notifyMatchFailure(op,
1445  "requires ranked input/output types");
1446 
1447  auto batch = inputTy.getDimSize(0);
1448  auto channels = inputTy.getDimSize(3);
1449  auto inputH = inputTy.getDimSize(1);
1450  auto inputW = inputTy.getDimSize(2);
1451  auto outputH = resultTy.getDimSize(1);
1452  auto outputW = resultTy.getDimSize(2);
1453 
1454  if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1455  return rewriter.notifyMatchFailure(
1456  op, "tosa.resize has no broadcasting behavior");
1457 
1458  // For any dimension that is broadcastable we generate a size of 1
1459  // on the intermediate resize output.
1460  llvm::SmallVector<int64_t> resizeShape;
1461  resizeShape.push_back(batch);
1462  resizeShape.push_back(inputH == 1 ? 1 : outputH);
1463  resizeShape.push_back(inputW == 1 ? 1 : outputW);
1464  resizeShape.push_back(channels);
1465 
1466  auto resizeTy = resultTy.clone(resizeShape);
1467  auto resize =
1468  builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1469 
1470  // Collapse any unit result dims; each unit (broadcast) dim is grouped with the dim that follows it.
1471  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1472  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1473  reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1474  if (inputH != 1)
1475  reassociationMap.push_back({});
1476  reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1477  if (inputW != 1)
1478  reassociationMap.push_back({});
1479  reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1480 
1481  llvm::SmallVector<int64_t> collapseShape{batch};
1482  if (inputH != 1)
1483  collapseShape.push_back(outputH);
1484  if (inputW != 1)
1485  collapseShape.push_back(outputW);
1486  collapseShape.push_back(channels);
1487 
1488  auto collapseTy = resultTy.clone(collapseShape);
1489  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1490  reassociationMap);
1491 
1492  // Broadcast the collapsed shape to the output result.
1493  llvm::SmallVector<Value> outputDynSize;
1494  if (inputTy.isDynamicDim(0))
1495  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1496  if (inputTy.isDynamicDim(3))
1497  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1498 
1499  SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1500  utils::IteratorType::parallel);
1501  Value empty = builder.create<tensor::EmptyOp>(
1502  resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1503 
1504  SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1505  if (inputH != 1)
1506  inputExprs.push_back(rewriter.getAffineDimExpr(1));
1507  if (inputW != 1)
1508  inputExprs.push_back(rewriter.getAffineDimExpr(2));
1509  inputExprs.push_back(rewriter.getAffineDimExpr(3));
1510 
1511  auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1512  inputExprs, rewriter.getContext());
1513 
1514  auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1515  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1516  op, resultTy, ValueRange{collapse}, ValueRange{empty},
1517  ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1518  [=](OpBuilder &b, Location loc, ValueRange args) {
1519  Value value = args[0];
1520  b.create<linalg::YieldOp>(loc, value);
1521  });
1522 
1523  return success();
1524  }
1525 };
1526 
1527 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1528 public:
1529  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1530 
1531  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1532  PatternRewriter &rewriter) const final {
1533  Location loc = op.getLoc();
1534  ImplicitLocOpBuilder b(loc, rewriter);
1535  auto input = op.getInput();
1536  auto inputTy = cast<ShapedType>(input.getType());
1537  auto resultTy = cast<ShapedType>(op.getType());
1538  auto resultETy = resultTy.getElementType();
1539 
1540  bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1541  auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1542 
1543  auto imageH = inputTy.getShape()[1];
1544  auto imageW = inputTy.getShape()[2];
1545 
1546  auto dynamicDimsOr =
1547  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1548  if (!dynamicDimsOr.has_value())
1549  return rewriter.notifyMatchFailure(
1550  op, "unable to get dynamic dimensions of tosa.resize");
1551 
1552  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1553  return rewriter.notifyMatchFailure(
1554  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1555 
1556  SmallVector<AffineMap, 2> affineMaps = {
1557  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1558  auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1559  *dynamicDimsOr);
1560  auto genericOp = b.create<linalg::GenericOp>(
1561  resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1562  getNParallelLoopsAttrs(resultTy.getRank()));
1563  Value resize = genericOp.getResult(0);
1564 
1565  {
1566  OpBuilder::InsertionGuard regionGuard(b);
1567  b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1568  TypeRange({resultETy}), loc);
1569  Value batch = b.create<linalg::IndexOp>(0);
1570  Value y = b.create<linalg::IndexOp>(1);
1571  Value x = b.create<linalg::IndexOp>(2);
1572  Value channel = b.create<linalg::IndexOp>(3);
1573 
1574  Value zeroI32 =
1575  b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1576  Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1577  Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1578  Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1579 
1580  Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1581  Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1582 
1583  ArrayRef<int64_t> offset = op.getOffset();
1584  ArrayRef<int64_t> border = op.getBorder();
1585  ArrayRef<int64_t> scale = op.getScale();
1586 
1587  Value yScaleN, yScaleD, xScaleN, xScaleD;
1588  yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1589  yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1590  xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1591  xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1592 
1593  Value yOffset, xOffset, yBorder, xBorder;
1594  yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1595  xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1596  yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1597  xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1598 
1599  // Compute the ix and dx values for both the X and Y dimensions.
1600  auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1601  Value scaleN, Value scaleD, Value offset,
1602  int size, ImplicitLocOpBuilder &b) {
1603  if (size == 1) {
1604  index = zeroI32;
1605  delta = zeroFp;
1606  return;
1607  }
1608  // x = x * scale_d + offset;
1609  // ix = floor(x / scale_n)
1610  Value val = b.create<arith::MulIOp>(in, scaleD);
1611  val = b.create<arith::AddIOp>(val, offset);
1612  index = b.create<arith::FloorDivSIOp>(val, scaleN);
1613 
1614  // rx = x % scale_n
1615  // dx = rx / scale_n
1616  Value r = b.create<arith::RemSIOp>(val, scaleN);
1617  Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1618  Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1619  delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1620  };
1621 
1622  // Compute the ix and dx values for the X and Y dimensions - int case.
1623  auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1624  Value scaleN, Value scaleD, Value offset,
1625  int size, ImplicitLocOpBuilder &b) {
1626  if (size == 1) {
1627  index = zeroI32;
1628  delta = zeroI32;
1629  return;
1630  }
1631  // x = x * scale_d + offset;
1632  // ix = floor(x / scale_n)
1633  // dx = x - ix * scale_n;
1634  Value val = b.create<arith::MulIOp>(in, scaleD);
1635  val = b.create<arith::AddIOp>(val, offset);
1636  index = b.create<arith::DivSIOp>(val, scaleN);
1637  delta = b.create<arith::MulIOp>(index, scaleN);
1638  delta = b.create<arith::SubIOp>(val, delta);
1639  };
1640 
1641  Value ix, iy, dx, dy;
1642  if (floatingPointMode) {
1643  getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1644  getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1645  } else {
1646  getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1647  getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1648  }
1649 
1650  if (op.getMode() == "NEAREST_NEIGHBOR") {
1651  auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1652 
1653  auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1654  Value max, int size,
1655  ImplicitLocOpBuilder &b) -> Value {
1656  if (size == 1) {
1657  return b.create<arith::ConstantIndexOp>(0);
1658  }
1659 
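  // Round to the nearest input index: in floating-point mode compare the
  // fractional part against 0.5; in integer mode compare 2 * dval against
  // scale_n, which is the same test without a division.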
1660  Value pred;
1661  if (floatingPointMode) {
1662  auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1663  pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1664  } else {
1665  Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1666  pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1667  dvalDouble, scale);
1668  }
1669 
1670  auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1671  val = b.create<arith::AddIOp>(val, offset);
1672  val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
1673  return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1674  };
1675 
1676  iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1677  ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1678 
1679  Value result = b.create<tensor::ExtractOp>(
1680  input, ValueRange{batch, iy, ix, channel});
1681 
1682  b.create<linalg::YieldOp>(result);
1683  } else {
1684  // The mode here must be BILINEAR.
1685  assert(op.getMode() == "BILINEAR");
1686 
1687  auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1688 
1689  auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1690  Value max, ImplicitLocOpBuilder &b) {
1691  val0 = in;
1692  val1 = b.create<arith::AddIOp>(val0, oneVal);
1693  val0 =
1694  clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1695  val1 =
1696  clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
1697  val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1698  val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1699  };
1700 
1701  // Linalg equivalent to the section below:
1702  // int16_t iy0 = apply_max(iy, 0);
1703  // int16_t iy1 = apply_min(iy + 1, IH - 1);
1704  // int16_t ix0 = apply_max(ix, 0);
1705  // int16_t ix1 = apply_min(ix + 1, IW - 1);
1706  Value x0, x1, y0, y1;
1707  getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1708  getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1709 
1710  Value y0x0 = b.create<tensor::ExtractOp>(
1711  input, ValueRange{batch, y0, x0, channel});
1712  Value y0x1 = b.create<tensor::ExtractOp>(
1713  input, ValueRange{batch, y0, x1, channel});
1714  Value y1x0 = b.create<tensor::ExtractOp>(
1715  input, ValueRange{batch, y1, x0, channel});
1716  Value y1x1 = b.create<tensor::ExtractOp>(
1717  input, ValueRange{batch, y1, x1, channel});
1718 
1719  if (floatingPointMode) {
1720  auto oneVal =
1721  b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1722  auto interpolate = [&](Value val0, Value val1, Value delta,
1723  int inputSize,
1724  ImplicitLocOpBuilder &b) -> Value {
1725  if (inputSize == 1)
1726  return val0;
1727  Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1728  Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1729  Value mul1 = b.create<arith::MulFOp>(val1, delta);
1730  return b.create<arith::AddFOp>(mul0, mul1);
1731  };
1732 
1733  // Linalg equivalent to the section below:
1734  // topAcc = v00 * (unit_x - dx);
1735  // topAcc += v01 * dx;
1736  Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1737 
1738  // Linalg equivalent to the section below:
1739  // bottomAcc = v10 * (unit_x - dx);
1740  // bottomAcc += v11 * dx;
1741  Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1742 
1743  // Linalg equivalent to the section below:
1744  // result = topAcc * (unit_y - dy) + bottomAcc * dy
1745  Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1746  b.create<linalg::YieldOp>(result);
1747  } else {
1748  // Perform in quantized space.
1749  y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1750  y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1751  y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1752  y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1753 
1754  const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1755  if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1756  dx = b.create<arith::ExtSIOp>(resultETy, dx);
1757  dy = b.create<arith::ExtSIOp>(resultETy, dy);
1758  }
1759 
1760  Value yScaleNExt = yScaleN;
1761  Value xScaleNExt = xScaleN;
1762 
1763  const int64_t scaleBitwidth =
1764  xScaleN.getType().getIntOrFloatBitWidth();
1765  if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1766  yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1767  xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1768  }
1769 
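  // Integer interpolation: the two weights sum to the scale numerator, so
  // each application multiplies the accumulator by scale_n and the final
  // value carries a factor of x_scale_n * y_scale_n, which a subsequent
  // rescale is expected to remove.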
1770  auto interpolate = [](Value val0, Value val1, Value weight1,
1771  Value scale, int inputSize,
1772  ImplicitLocOpBuilder &b) -> Value {
1773  if (inputSize == 1)
1774  return b.create<arith::MulIOp>(val0, scale);
1775  Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1776  Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1777  Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1778  return b.create<arith::AddIOp>(mul0, mul1);
1779  };
1780 
1781  Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1782  Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1783  Value result =
1784  interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1785  b.create<linalg::YieldOp>(result);
1786  }
1787  }
1788  }
1789 
1790  rewriter.replaceOp(op, resize);
1791  return success();
1792  }
1793 };
1794 
1795 // At the codegen level any identity operations should be removed. Any cases
1796 // where identity is load-bearing (e.g. cross device computation) should be
1797 // handled before lowering to codegen.
1798 template <typename SrcOp>
1799 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1800 public:
1801  using OpRewritePattern<SrcOp>::OpRewritePattern;
1802 
1803  LogicalResult matchAndRewrite(SrcOp op,
1804  PatternRewriter &rewriter) const final {
1805  rewriter.replaceOp(op, op.getOperation()->getOperands());
1806  return success();
1807  }
1808 };
1809 
1810 template <typename SrcOp>
1811 class ReduceConverter : public OpRewritePattern<SrcOp> {
1812 public:
1813  using OpRewritePattern<SrcOp>::OpRewritePattern;
1814 
1815  LogicalResult matchAndRewrite(SrcOp reduceOp,
1816  PatternRewriter &rewriter) const final {
1817  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1818  }
1819 };
1820 
1821 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1822 public:
1823  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1824 
1825  LogicalResult matchAndRewrite(tosa::ReverseOp op,
1826  PatternRewriter &rewriter) const final {
1827  auto loc = op.getLoc();
1828  Value input = op.getInput();
1829  auto inputTy = cast<ShapedType>(input.getType());
1830  auto resultTy = cast<ShapedType>(op.getType());
1831  auto axis = op.getAxis();
1832 
1833  SmallVector<Value> dynDims;
1834  for (int i = 0; i < inputTy.getRank(); i++) {
1835  if (inputTy.isDynamicDim(i)) {
1836  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1837  }
1838  }
1839 
1840  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1841 
1842  // First fill the output buffer with the init value.
1843  auto emptyTensor = rewriter
1844  .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1845  inputTy.getElementType(),
1846  ArrayRef<Value>({dynDims}))
1847  .getResult();
1848  SmallVector<AffineMap, 2> affineMaps = {
1849  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1850 
1851  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1852  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
1853  getNParallelLoopsAttrs(resultTy.getRank()),
1854  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1855  llvm::SmallVector<Value> indices;
1856  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1857  Value index =
1858  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
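  // Along the reversed axis, read from the mirrored position
  // (axisDimSize - 1 - index).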
1859  if (i == axis) {
1860  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1861  auto sizeMinusOne =
1862  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1863  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1864  index);
1865  }
1866 
1867  indices.push_back(index);
1868  }
1869 
1870  auto extract = nestedBuilder.create<tensor::ExtractOp>(
1871  nestedLoc, input, indices);
1872  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1873  extract.getResult());
1874  });
1875  return success();
1876  }
1877 };
1878 
1879 // This converter translates a tile operation into a reshape, broadcast, and
1880 // reshape. The first reshape minimally expands each tiled dimension to
1881 // include a preceding size-1 dim. This dim is then broadcast to the
1882 // appropriate multiple.
1883 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1884  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1885 
1886  LogicalResult
1887  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1888  ConversionPatternRewriter &rewriter) const override {
1889  auto loc = op.getLoc();
1890  auto input = op.getInput1();
1891  auto inputTy = cast<ShapedType>(input.getType());
1892  auto inputShape = inputTy.getShape();
1893  auto resultTy = cast<ShapedType>(op.getType());
1894  auto elementTy = inputTy.getElementType();
1895  int64_t rank = inputTy.getRank();
1896 
1897  ArrayRef<int64_t> multiples = op.getMultiples();
1898 
1899  // Broadcast the newly added dimensions to their appropriate multiple.
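  // The intermediate shape interleaves the multiples with the input sizes:
  // [multiple0, size0, multiple1, size1, ...].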
1900  SmallVector<int64_t, 2> genericShape;
1901  for (int i = 0; i < rank; i++) {
1902  int64_t dim = multiples[i];
1903  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1904  genericShape.push_back(inputShape[i]);
1905  }
1906 
1907  SmallVector<Value> dynDims;
1908  for (int i = 0; i < inputTy.getRank(); i++) {
1909  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1910  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1911  }
1912  }
1913 
1914  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1915  op.getLoc(), genericShape, elementTy, dynDims);
1916 
1917  // We need to map the input shape to the non-broadcast dimensions.
1918  SmallVector<AffineExpr, 4> dimExprs;
1919  dimExprs.reserve(rank);
1920  for (unsigned i = 0; i < rank; ++i)
1921  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1922 
1923  auto readAffineMap =
1924  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1925  rewriter.getContext());
1926 
1927  SmallVector<AffineMap, 2> affineMaps = {
1928  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1929 
1930  auto genericOp = rewriter.create<linalg::GenericOp>(
1931  loc, RankedTensorType::get(genericShape, elementTy), input,
1932  ValueRange{emptyTensor}, affineMaps,
1933  getNParallelLoopsAttrs(genericShape.size()),
1934  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1935  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1936  });
1937 
1938  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1939  op, resultTy, genericOp.getResult(0),
1940  rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
1941  return success();
1942  }
1943 };
1944 
1945 // Tosa argmax lowering represents the ArgMax op as a linalg.generic op,
1946 // producing two output buffers.
1947 //
1948 // The first output buffer contains the index of the found maximum value. It is
1949 // initialized to 0 and is of the resulting integer type.
1950 //
1951 // The second output buffer contains the maximum value found. It is initialized
1952 // to the minimum representable value of the input element type. After being
1953 // populated by the generic op, this buffer is discarded as only the index is
1954 // requested.
1955 //
1956 // The generic op updates both the maximum value and the index whenever the
1957 // current value exceeds the running max.
1958 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1959 public:
1960  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
1961 
1962  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1963  PatternRewriter &rewriter) const final {
1964  auto loc = argmaxOp.getLoc();
1965  Value input = argmaxOp.getInput();
1966  auto inputTy = cast<ShapedType>(input.getType());
1967  auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
1968  auto inElementTy = inputTy.getElementType();
1969  auto outElementTy = resultTy.getElementType();
1970  int axis = argmaxOp.getAxis();
1971  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1972 
1973  if (!isa<IntegerType>(outElementTy))
1974  return rewriter.notifyMatchFailure(
1975  argmaxOp,
1976  "tosa.arg_max to linalg.* requires integer-like result type");
1977 
1978  SmallVector<Value> dynDims;
1979  for (int i = 0; i < inputTy.getRank(); i++) {
1980  if (inputTy.isDynamicDim(i) && i != axis) {
1981  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1982  }
1983  }
1984 
1985  // First fill the output buffer for the index.
1986  auto emptyTensorIdx = rewriter
1987  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1988  outElementTy, dynDims)
1989  .getResult();
1990  auto fillValueIdx = rewriter.create<arith::ConstantOp>(
1991  loc, rewriter.getIntegerAttr(outElementTy, 0));
1992  auto filledTensorIdx =
1993  rewriter
1994  .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
1995  ValueRange{emptyTensorIdx})
1996  .result();
1997 
1998  // Second fill the output buffer for the running max.
1999  auto emptyTensorMax = rewriter
2000  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2001  inElementTy, dynDims)
2002  .getResult();
2003  auto fillValueMaxAttr =
2004  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2005 
2006  if (!fillValueMaxAttr)
2007  return rewriter.notifyMatchFailure(
2008  argmaxOp, "unsupported tosa.argmax element type");
2009 
2010  auto fillValueMax =
2011  rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
2012  auto filledTensorMax =
2013  rewriter
2014  .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
2015  ValueRange{emptyTensorMax})
2016  .result();
2017 
2018  // We need to reduce along the arg-max axis, with parallel operations along
2019  // the rest.
2020  SmallVector<utils::IteratorType> iteratorTypes;
2021  iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2022  iteratorTypes[axis] = utils::IteratorType::reduction;
2023 
2024  SmallVector<AffineExpr, 2> srcExprs;
2025  SmallVector<AffineExpr, 2> dstExprs;
2026  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2027  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2028  if (axis != i)
2029  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2030  }
2031 
2032  bool didEncounterError = false;
2033  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2034  rewriter.getContext());
2035  auto linalgOp = rewriter.create<linalg::GenericOp>(
2036  loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2037  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2038  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2039  ValueRange blockArgs) {
2040  auto newValue = blockArgs[0];
2041  auto oldIndex = blockArgs[1];
2042  auto oldValue = blockArgs[2];
2043 
2044  Value newIndex = rewriter.create<arith::IndexCastOp>(
2045  nestedLoc, oldIndex.getType(),
2046  rewriter.create<linalg::IndexOp>(loc, axis));
2047 
2048  Value predicate;
2049  if (isa<FloatType>(inElementTy)) {
2050  predicate = rewriter.create<arith::CmpFOp>(
2051  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2052  } else if (isa<IntegerType>(inElementTy)) {
2053  predicate = rewriter.create<arith::CmpIOp>(
2054  nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2055  } else {
2056  didEncounterError = true;
2057  return;
2058  }
2059 
2060  auto resultMax = rewriter.create<arith::SelectOp>(
2061  nestedLoc, predicate, newValue, oldValue);
2062  auto resultIndex = rewriter.create<arith::SelectOp>(
2063  nestedLoc, predicate, newIndex, oldIndex);
2064  nestedBuilder.create<linalg::YieldOp>(
2065  nestedLoc, ValueRange({resultIndex, resultMax}));
2066  });
2067 
2068  if (didEncounterError)
2069  return rewriter.notifyMatchFailure(
2070  argmaxOp, "unsupported tosa.argmax element type");
2071 
2072  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2073  return success();
2074  }
2075 };
2076 
2077 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2078 public:
2079  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2080  LogicalResult
2081  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2082  ConversionPatternRewriter &rewriter) const final {
2083  auto input = adaptor.getOperands()[0];
2084  auto indices = adaptor.getOperands()[1];
2085 
2086  auto valuesTy =
2087  dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2088  auto resultTy = cast<ShapedType>(op.getType());
2089 
2090  if (!valuesTy)
2091  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2092 
2093  auto dynamicDims = inferDynamicDimsForGather(
2094  rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2095 
2096  auto resultElementTy = resultTy.getElementType();
2097 
2098  auto loc = op.getLoc();
2099  auto emptyTensor =
2100  rewriter
2101  .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2102  dynamicDims)
2103  .getResult();
2104 
2105  SmallVector<AffineMap, 2> affineMaps = {
2106  AffineMap::get(
2107  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2108  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2109  rewriter.getContext()),
2110  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2111 
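  // The generic below gathers along the middle dimension of the values
  // tensor: output[b, i, c] = values[b, indices[b, i], c].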
2112  auto genericOp = rewriter.create<linalg::GenericOp>(
2113  loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2114  ValueRange{emptyTensor}, affineMaps,
2115  getNParallelLoopsAttrs(resultTy.getRank()),
2116  [&](OpBuilder &b, Location loc, ValueRange args) {
2117  auto indexValue = args[0];
2118  auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2119  Value index1 = rewriter.create<arith::IndexCastOp>(
2120  loc, rewriter.getIndexType(), indexValue);
2121  auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2122  Value extract = rewriter.create<tensor::ExtractOp>(
2123  loc, input, ValueRange{index0, index1, index2});
2124  rewriter.create<linalg::YieldOp>(loc, extract);
2125  });
2126  rewriter.replaceOp(op, genericOp.getResult(0));
2127  return success();
2128  }
2129 
2130  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2131  Location loc,
2132  Value values,
2133  Value indices) {
2134  llvm::SmallVector<Value> results;
2135 
2136  auto addDynamicDimension = [&](Value source, int64_t dim) {
2137  auto sz = tensor::getMixedSize(builder, loc, source, dim);
2138  if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2139  results.push_back(dimValue);
2140  };
2141 
2142  addDynamicDimension(values, 0);
2143  addDynamicDimension(indices, 1);
2144  addDynamicDimension(values, 2);
2145  return results;
2146  }
2147 };
2148 
2149 // Lowers the TableOp to a series of gathers and numeric operations. This
2150 // includes interpolation between the high/low values. For the i8 variant, this
2151 // simplifies to a single gather operation.
2152 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2153 public:
2154  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2155 
2156  LogicalResult matchAndRewrite(tosa::TableOp op,
2157  PatternRewriter &rewriter) const final {
2158  auto loc = op.getLoc();
2159  Value input = op.getInput();
2160  Value table = op.getTable();
2161  auto inputTy = cast<ShapedType>(input.getType());
2162  auto tableTy = cast<ShapedType>(table.getType());
2163  auto resultTy = cast<ShapedType>(op.getType());
2164 
2165  auto inputElementTy = inputTy.getElementType();
2166  auto tableElementTy = tableTy.getElementType();
2167  auto resultElementTy = resultTy.getElementType();
2168 
2169  SmallVector<Value> dynDims;
2170  for (int i = 0; i < resultTy.getRank(); ++i) {
2171  if (inputTy.isDynamicDim(i)) {
2172  dynDims.push_back(
2173  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2174  }
2175  }
2176 
2177  auto emptyTensor = rewriter
2178  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2179  resultElementTy, dynDims)
2180  .getResult();
2181 
2182  SmallVector<AffineMap, 2> affineMaps = {
2183  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2184  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2185 
2186  auto genericOp = rewriter.create<linalg::GenericOp>(
2187  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2188  getNParallelLoopsAttrs(resultTy.getRank()));
2189  rewriter.replaceOp(op, genericOp.getResult(0));
2190 
2191  {
2192  OpBuilder::InsertionGuard regionGuard(rewriter);
2193  Block *block = rewriter.createBlock(
2194  &genericOp.getRegion(), genericOp.getRegion().end(),
2195  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2196 
2197  auto inputValue = block->getArgument(0);
2198  rewriter.setInsertionPointToStart(block);
2199  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2200  resultElementTy.isInteger(8)) {
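  // The signed i8 input covers [-128, 127]; shifting it by 128 maps it onto
  // the 256-entry table so a single extract suffices.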
2201  Value index = rewriter.create<arith::IndexCastOp>(
2202  loc, rewriter.getIndexType(), inputValue);
2203  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2204  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2205  index, offset);
2206  Value extract =
2207  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2208  rewriter.create<linalg::YieldOp>(loc, extract);
2209  return success();
2210  }
2211 
2212  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2213  resultElementTy.isInteger(32)) {
2214  Value extend = rewriter.create<arith::ExtSIOp>(
2215  loc, rewriter.getI32Type(), inputValue);
2216 
2217  auto offset = rewriter.create<arith::ConstantOp>(
2218  loc, rewriter.getI32IntegerAttr(32768));
2219  auto seven = rewriter.create<arith::ConstantOp>(
2220  loc, rewriter.getI32IntegerAttr(7));
2221  auto one = rewriter.create<arith::ConstantOp>(
2222  loc, rewriter.getI32IntegerAttr(1));
2223  auto b1111111 = rewriter.create<arith::ConstantOp>(
2224  loc, rewriter.getI32IntegerAttr(127));
2225 
2226  // Compute the index and fractional part from the input value:
2227  // value = value + 32768
2228  // index = value >> 7;
2229  // fraction = 0x01111111 & value
2230  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2231  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2232  Value fraction =
2233  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2234 
2235  // Extract the base and next values from the table.
2236  // base = (int32_t) table[index];
2237  // next = (int32_t) table[index + 1];
2238  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2239 
2240  index = rewriter.create<arith::IndexCastOp>(
2241  loc, rewriter.getIndexType(), index);
2242  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2243  loc, rewriter.getIndexType(), indexPlusOne);
2244 
2245  Value base =
2246  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2247  Value next = rewriter.create<tensor::ExtractOp>(
2248  loc, table, ValueRange{indexPlusOne});
2249 
2250  base =
2251  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2252  next =
2253  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2254 
2255  // Use the fractional part to interpolate between the input values:
2256  // result = (base << 7) + (next - base) * fraction
2257  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2258  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2259  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2260  Value result =
2261  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2262 
2263  rewriter.create<linalg::YieldOp>(loc, result);
2264 
2265  return success();
2266  }
2267  }
2268 
2269  return rewriter.notifyMatchFailure(
2270  op, "unable to create body for tosa.table op");
2271  }
2272 };
2273 
2274 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2275  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2276 
2277  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2278 
2279  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2280  OpFoldResult ofr) {
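  // Computes (value / 2) + 1, the width of the non-redundant half of the
  // RFFT output along the W dimension.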
2281  auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
2282  auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
2283 
2284  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2285  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2286  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2287  return getAsOpFoldResult(plusOne);
2288  }
2289 
2290  static RankedTensorType
2291  computeOutputShape(OpBuilder &builder, Location loc, Value input,
2292  llvm::SmallVectorImpl<Value> &dynamicSizes) {
2293  // Get [N, H, W]
2294  auto dims = tensor::getMixedSizes(builder, loc, input);
2295 
2296  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2297  // output tensors.
2298  dims[2] = halfPlusOne(builder, loc, dims[2]);
2299 
2300  llvm::SmallVector<int64_t, 3> staticSizes;
2301  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2302 
2303  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2304  return RankedTensorType::get(staticSizes, elementType);
2305  }
2306 
2307  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2308  RankedTensorType type,
2309  llvm::ArrayRef<Value> dynamicSizes) {
2310  auto emptyTensor =
2311  rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2312  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2313  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2314  auto filledTensor = rewriter
2315  .create<linalg::FillOp>(loc, ValueRange{fillValue},
2316  ValueRange{emptyTensor})
2317  .result();
2318  return filledTensor;
2319  }
2320 
2321  static Value castIndexToFloat(OpBuilder &builder, Location loc,
2322  FloatType type, Value value) {
2323  auto integerVal = builder.create<arith::IndexCastUIOp>(
2324  loc,
2325  type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2326  : builder.getI32Type(),
2327  value);
2328 
2329  return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2330  }
2331 
2332  static Value createLinalgIndex(OpBuilder &builder, Location loc,
2333  FloatType type, int64_t index) {
2334  auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2335  return castIndexToFloat(builder, loc, type, indexVal);
2336  }
2337 
2338  template <typename... Args>
2339  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2340  Args... args) {
2341  return {builder.getAffineDimExpr(args)...};
2342  }
2343 
2344  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2345  PatternRewriter &rewriter) const override {
2346  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2347  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2348  return rewriter.notifyMatchFailure(rfft2d,
2349  "only supports ranked tensors");
2350  }
2351 
2352  auto loc = rfft2d.getLoc();
2353  auto input = rfft2d.getInput();
2354  auto elementType =
2355  dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2356  if (!elementType)
2357  return rewriter.notifyMatchFailure(rfft2d,
2358  "only supports float element types");
2359 
2360  // Compute the output type and set of dynamic sizes
2361  llvm::SmallVector<Value> dynamicSizes;
2362  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2363 
2364  // Iterator types for the linalg.generic implementation
2365  llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2366  utils::IteratorType::parallel, utils::IteratorType::parallel,
2367  utils::IteratorType::parallel, utils::IteratorType::reduction,
2368  utils::IteratorType::reduction};
2369 
2370  // Inputs/outputs to the linalg.generic implementation
2371  llvm::SmallVector<Value> genericOpInputs = {input};
2372  llvm::SmallVector<Value> genericOpOutputs = {
2373  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2374  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2375 
2376  // Indexing maps for input and output tensors
2377  auto indexingMaps = AffineMap::inferFromExprList(
2378  llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2379  affineDimsExpr(rewriter, 0, 1, 2),
2380  affineDimsExpr(rewriter, 0, 1, 2)},
2381  rewriter.getContext());
2382 
2383  // Width and height dimensions of the original input.
2384  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2385  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2386 
2387  // Constants and dimension sizes
2388  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2389  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2390  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2391  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2392 
2393  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2394  Value valReal = args[0];
2395  Value sumReal = args[1];
2396  Value sumImag = args[2];
2397 
2398  // Indices for angle computation
2399  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2400  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2401  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2402  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2403 
2404  // Calculating angle without integer parts of components as sin/cos are
2405  // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2406  // / W);
2407  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2408  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2409 
2410  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2411  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2412 
2413  auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2414  auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2415 
2416  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2417  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2418  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2419  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2420 
2421  // realComponent = valReal * cos(angle)
2422  // imagComponent = valReal * sin(angle)
2423  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2424  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2425  auto realComponent =
2426  builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2427  auto imagComponent =
2428  builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2429 
2430  // outReal = sumReal + realComponent
2431  // outImag = sumImag - imagComponent
2432  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2433  auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2434 
2435  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2436  };
2437 
2438  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2439  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2440  indexingMaps, iteratorTypes, buildBody);
2441 
2442  return success();
2443  }
2444 };
2445 
2446 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2447  using OpRewritePattern<FFT2dOp>::OpRewritePattern;
2448 
2449  LogicalResult matchAndRewrite(FFT2dOp fft2d,
2450  PatternRewriter &rewriter) const override {
2451  if (!llvm::all_of(fft2d->getOperandTypes(),
2452  RFFT2dConverter::isRankedTensor) ||
2453  !llvm::all_of(fft2d->getResultTypes(),
2454  RFFT2dConverter::isRankedTensor)) {
2455  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2456  }
2457 
2458  Location loc = fft2d.getLoc();
2459  Value input_real = fft2d.getInputReal();
2460  Value input_imag = fft2d.getInputImag();
2461  BoolAttr inverse = fft2d.getInverseAttr();
2462 
2463  auto real_el_ty = cast<FloatType>(
2464  cast<ShapedType>(input_real.getType()).getElementType());
2465  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2466  cast<ShapedType>(input_imag.getType()).getElementType());
2467 
2468  assert(real_el_ty == imag_el_ty);
2469 
2470  // Compute the output type and set of dynamic sizes
2471  SmallVector<Value> dynamicSizes;
2472 
2473  // Get [N, H, W]
2474  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2475 
2476  SmallVector<int64_t, 3> staticSizes;
2477  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2478 
2479  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2480 
2481  // Iterator types for the linalg.generic implementation
2482  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2483  utils::IteratorType::parallel, utils::IteratorType::parallel,
2484  utils::IteratorType::parallel, utils::IteratorType::reduction,
2485  utils::IteratorType::reduction};
2486 
2487  // Inputs/outputs to the linalg.generic implementation
2488  SmallVector<Value> genericOpInputs = {input_real, input_imag};
2489  SmallVector<Value> genericOpOutputs = {
2490  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2491  dynamicSizes),
2492  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2493  dynamicSizes)};
2494 
2495  // Indexing maps for input and output tensors
2496  auto indexingMaps = AffineMap::inferFromExprList(
2497  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2498  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2499  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2500  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2501  rewriter.getContext());
2502 
2503  // Width and height dimensions of the original input.
2504  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2505  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2506 
2507  // Constants and dimension sizes
2508  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2509  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2510  Value constH =
2511  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2512  Value constW =
2513  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2514 
2515  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2516  Value valReal = args[0];
2517  Value valImag = args[1];
2518  Value sumReal = args[2];
2519  Value sumImag = args[3];
2520 
2521  // Indices for angle computation
2522  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2523  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2524  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2525  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2526 
2527  // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2528  // ox) % W ) / W);
2529  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2530  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2531 
2532  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2533  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2534 
2535  auto iyRemFloat =
2536  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2537  auto ixRemFloat =
2538  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2539 
2540  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2541  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2542 
2543  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2544  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2545 
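  // For the inverse transform, negate the angle (conjugating the twiddle
  // factor).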
2546  if (inverse.getValue()) {
2547  angle = builder.create<arith::MulFOp>(
2548  loc, angle,
2549  rewriter.create<arith::ConstantOp>(
2550  loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2551  }
2552 
2553  // realComponent = val_real * cos(a) + val_imag * sin(a);
2554  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2555  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2556  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2557 
2558  auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2559  auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2560  auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2561 
2562  auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2563  auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2564 
2565  auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2566 
2567  // outReal = sumReal + realComponent
2568  // outImag = sumImag - imagComponent
2569  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2570  auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2571 
2572  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2573  };
2574 
2575  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2576  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2577  indexingMaps, iteratorTypes, buildBody);
2578 
2579  return success();
2580  }
2581 };
2582 
2583 } // namespace
2584 
2584 
2585 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2586  TypeConverter &converter, RewritePatternSet *patterns) {
2587 
2588  // We have multiple resize converters to handle degenerate cases.
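 // Higher-benefit patterns are tried first, so the specialized resize
 // converters take precedence over GenericResizeConverter.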
2589  patterns->add<GenericResizeConverter>(patterns->getContext(),
2590  /*benefit=*/100);
2591  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2592  /*benefit=*/200);
2593  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2594  /*benefit=*/300);
2595 
2596  patterns->add<
2597  // clang-format off
2598  PointwiseConverter<tosa::AddOp>,
2599  PointwiseConverter<tosa::SubOp>,
2600  PointwiseConverter<tosa::MulOp>,
2601  PointwiseConverter<tosa::IntDivOp>,
2602  PointwiseConverter<tosa::NegateOp>,
2603  PointwiseConverter<tosa::PowOp>,
2604  PointwiseConverter<tosa::ReciprocalOp>,
2605  PointwiseConverter<tosa::RsqrtOp>,
2606  PointwiseConverter<tosa::LogOp>,
2607  PointwiseConverter<tosa::ExpOp>,
2608  PointwiseConverter<tosa::AbsOp>,
2609  PointwiseConverter<tosa::SinOp>,
2610  PointwiseConverter<tosa::CosOp>,
2611  PointwiseConverter<tosa::TanhOp>,
2612  PointwiseConverter<tosa::ErfOp>,
2613  PointwiseConverter<tosa::BitwiseAndOp>,
2614  PointwiseConverter<tosa::BitwiseOrOp>,
2615  PointwiseConverter<tosa::BitwiseNotOp>,
2616  PointwiseConverter<tosa::BitwiseXorOp>,
2617  PointwiseConverter<tosa::LogicalAndOp>,
2618  PointwiseConverter<tosa::LogicalNotOp>,
2619  PointwiseConverter<tosa::LogicalOrOp>,
2620  PointwiseConverter<tosa::LogicalXorOp>,
2621  PointwiseConverter<tosa::CastOp>,
2622  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2623  PointwiseConverter<tosa::LogicalRightShiftOp>,
2624  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2625  PointwiseConverter<tosa::ClzOp>,
2626  PointwiseConverter<tosa::SelectOp>,
2627  PointwiseConverter<tosa::GreaterOp>,
2628  PointwiseConverter<tosa::GreaterEqualOp>,
2629  PointwiseConverter<tosa::EqualOp>,
2630  PointwiseConverter<tosa::MaximumOp>,
2631  PointwiseConverter<tosa::MinimumOp>,
2632  PointwiseConverter<tosa::CeilOp>,
2633  PointwiseConverter<tosa::FloorOp>,
2634  PointwiseConverter<tosa::ClampOp>,
2635  PointwiseConverter<tosa::SigmoidOp>
2636  >(converter, patterns->getContext());
2637 
2638  patterns->add<
2639  IdentityNConverter<tosa::IdentityOp>,
2640  ReduceConverter<tosa::ReduceAllOp>,
2641  ReduceConverter<tosa::ReduceAnyOp>,
2642  ReduceConverter<tosa::ReduceMinOp>,
2643  ReduceConverter<tosa::ReduceMaxOp>,
2644  ReduceConverter<tosa::ReduceSumOp>,
2645  ReduceConverter<tosa::ReduceProdOp>,
2646  ArgMaxConverter,
2647  GatherConverter,
2648  RescaleConverter,
2649  ReverseConverter,
2650  RFFT2dConverter,
2651  FFT2dConverter,
2652  TableConverter,
2653  TileConverter>(patterns->getContext());
2654  // clang-format on
2655 }