MLIR  19.0.0git
TosaToLinalg.cpp
Go to the documentation of this file.
1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/PatternMatch.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/Sequence.h"
34 
35 #include <numeric>
36 
37 using namespace mlir;
38 using namespace mlir::tosa;
39 
40 template <typename T>
41 static arith::ConstantOp
42 createConstFromIntAttribute(Operation *op, const std::string &attrName,
43  Type requiredAttrType, OpBuilder &rewriter) {
44  auto castedN = static_cast<T>(
45  cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
46  return rewriter.create<arith::ConstantOp>(
47  op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
48 }
49 
// Builds the scalar computation for the body of an elementwise TOSA op.
//
// Dispatches on the kind of `op` and (for most ops) on the element type of its
// first operand. It emits arith/math ops (and tosa.apply_scale for an integer
// MulOp with a positive shift) over the scalar block arguments `args`, which
// produce a value of `resultTypes`. Returns the computed Value. When an
// op/element-type combination is not handled, it reports a match failure on
// `rewriter` and returns nullptr.
50 static Value
52  ArrayRef<Type> resultTypes,
53  PatternRewriter &rewriter) {
54  Location loc = op->getLoc();
55  auto elementTy =
56  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
57 
58  // tosa::AbsOp
59  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
60  return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
61 
62  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
63  auto zero = rewriter.create<arith::ConstantOp>(
64  loc, rewriter.getZeroAttr(elementTy));
65  auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
66  return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
67  }
68 
69  // tosa::AddOp
70  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
71  return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
72 
73  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
74  return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
75 
76  // tosa::SubOp
77  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
78  return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
79 
80  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
81  return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
82 
83  // tosa::MulOp
84  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
85  if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
86  (void)rewriter.notifyMatchFailure(op,
87  "Cannot have shift value for float");
88  return nullptr;
89  }
90  return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
91  }
92 
93  // tosa::DivOp
94  if (isa<tosa::DivOp>(op) && isa<IntegerType>(elementTy))
95  return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
96 
97  // tosa::ReciprocalOp
98  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
99  auto one =
100  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
101  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
102  }
103 
  // Integer tosa::MulOp. A positive shift goes through tosa.apply_scale on
  // i32 operands; otherwise it is a plain multiply after sign-extending
  // narrower operands to the result width.
104  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
105  Value a = args[0];
106  Value b = args[1];
107  auto shift =
108  cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
109  if (shift > 0) {
110  auto shiftConst =
111  rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
112  if (!a.getType().isInteger(32))
113  a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
114 
115  if (!b.getType().isInteger(32))
116  b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
117 
118  auto result = rewriter.create<tosa::ApplyScaleOp>(
119  loc, rewriter.getI32Type(), a, b, shiftConst,
120  rewriter.getBoolAttr(false));
121 
122  if (elementTy.isInteger(32))
123  return result;
124 
125  return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
126  }
127 
128  int aWidth = a.getType().getIntOrFloatBitWidth();
129  int bWidth = b.getType().getIntOrFloatBitWidth();
130  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
131 
132  if (aWidth < cWidth)
133  a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
134  if (bWidth < cWidth)
135  b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
136 
137  return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
138  }
139 
140  // tosa::NegateOp
141  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
142  return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
143 
144  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
145  !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
146  auto constant =
147  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
148  return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
149  }
150 
151  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
152  cast<tosa::NegateOp>(op).getQuantizationInfo()) {
153  auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
154  int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
155  int64_t inZp = quantizationInfo.value().getInputZp();
156  int64_t outZp = quantizationInfo.value().getOutputZp();
157 
158  // Compute the maximum value that can occur in the intermediate buffer.
  // NOTE(review): for a 64-bit element type this sum overflows int64_t.
159  int64_t zpAdd = inZp + outZp;
160  int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
161  std::abs(zpAdd) + 1;
162 
163  // Convert that maximum value into the maximum bitwidth needed to represent
164  // it. We assume 48-bit numbers may be supported further in the pipeline.
165  int intermediateBitWidth = 64;
166  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
167  intermediateBitWidth = 16;
168  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
169  intermediateBitWidth = 32;
170  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
171  intermediateBitWidth = 48;
172  }
173 
174  Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
175  Value zpAddValue = rewriter.create<arith::ConstantOp>(
176  loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
177 
178  // The negation can be applied by doing:
179  // outputValue = inZp + outZp - inputValue
180  auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
181  auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
182 
183  // Clamp to the negation range.
184  Value min = rewriter.create<arith::ConstantIntOp>(
185  loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
186  intermediateType);
187  Value max = rewriter.create<arith::ConstantIntOp>(
188  loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
189  intermediateType);
190  auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
191 
192  // Truncate to the final value.
193  return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
194  }
195 
196  // tosa::BitwiseAndOp
197  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
198  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
199 
200  // tosa::BitwiseOrOp
201  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
202  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
203 
204  // tosa::BitwiseNotOp
205  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
206  auto allOnesAttr = rewriter.getIntegerAttr(
207  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
208  auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
209  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
210  }
211 
212  // tosa::BitwiseXOrOp
213  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
214  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
215 
216  // tosa::LogicalLeftShiftOp
217  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
218  return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
219 
220  // tosa::LogicalRightShiftOp
221  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
222  return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
223 
224  // tosa::ArithmeticRightShiftOp
  // With round=true, 1 is added to the shifted result when the shift amount is
  // positive and the last bit shifted out of input1 is set.
225  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
226  auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
227  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
228  if (!round) {
229  return result;
230  }
231 
232  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
233  auto one =
234  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
235  auto zero =
236  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
237  auto i1one =
238  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
239 
240  // Checking that input2 > 0
241  auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
242  loc, arith::CmpIPredicate::sgt, args[1], zero);
243 
244  // Checking for the last bit of input1 to be 1
245  auto subtract =
246  rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
247  auto shifted =
248  rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
249  ->getResults();
250  auto truncated =
251  rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
252  auto isInputOdd =
253  rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
254 
255  auto shouldRound = rewriter.create<arith::AndIOp>(
256  loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
257  auto extended =
258  rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
259  return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
260  }
261 
262  // tosa::ClzOp
263  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
264  return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
265  }
266 
267  // tosa::LogicalAnd
268  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
269  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
270 
271  // tosa::LogicalNot
272  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
273  auto one = rewriter.create<arith::ConstantOp>(
274  loc, rewriter.getIntegerAttr(elementTy, 1));
275  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
276  }
277 
278  // tosa::LogicalOr
279  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
280  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
281 
282  // tosa::LogicalXor
283  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
284  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
285 
286  // tosa::PowOp
287  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
288  return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
289 
290  // tosa::RsqrtOp
291  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
292  return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
293 
294  // tosa::LogOp
295  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
296  return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
297 
298  // tosa::ExpOp
299  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
300  return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
301 
302  // tosa::TanhOp
303  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
304  return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
305 
306  // tosa::ErfOp
307  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
308  return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
309 
310  // tosa::GreaterOp
311  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
312  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
313  args[0], args[1]);
314 
315  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
316  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
317  args[0], args[1]);
318 
319  // tosa::GreaterEqualOp
320  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
321  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
322  args[0], args[1]);
323 
324  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
325  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
326  args[0], args[1]);
327 
328  // tosa::EqualOp
329  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
330  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
331  args[0], args[1]);
332 
333  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
334  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
335  args[0], args[1]);
336 
337  // tosa::SelectOp
  // Operand 0 is the condition, so the element type is taken from operand 1.
338  if (isa<tosa::SelectOp>(op)) {
339  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
340  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
341  return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
342  }
343 
344  // tosa::MaximumOp
345  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
346  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
347  }
348 
349  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
350  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
351  }
352 
353  // tosa::MinimumOp
354  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
355  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
356  }
357 
358  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
359  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
360  }
361 
362  // tosa::CeilOp
363  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
364  return rewriter.create<math::CeilOp>(loc, resultTypes, args);
365 
366  // tosa::FloorOp
367  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
368  return rewriter.create<math::FloorOp>(loc, resultTypes, args);
369 
370  // tosa::ClampOp
371  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
372  bool losesInfo = false;
373  APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
374  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
375  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
376  APFloat::rmNearestTiesToEven, &losesInfo);
377  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
378  APFloat::rmNearestTiesToEven, &losesInfo);
379  auto min = rewriter.create<arith::ConstantOp>(
380  loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
381  auto max = rewriter.create<arith::ConstantOp>(
382  loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
383  return clampFloatHelper(loc, args[0], min, max, rewriter);
384  }
385 
386  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
387  auto intTy = cast<IntegerType>(elementTy);
388  int64_t min =
389  cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
390  int64_t max =
391  cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
392 
  // Narrow the requested bounds to the range representable by the element
  // type.
  // NOTE(review): APInt::getMaxValue(width) is the all-ones value, so its
  // getSExtValue() is -1. The unsigned upper bound below therefore clamps
  // `max` to -1. getZExtValue() (guarded for width 64) looks intended.
  // TODO confirm and fix.
393  if (intTy.isUnsignedInteger()) {
394  min = std::max(min, (int64_t)0);
395  max = std::min(
396  max,
397  APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
398  } else {
399  min =
400  std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
401  .getSExtValue());
402  max =
403  std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
404  .getSExtValue());
405  }
406 
407  auto minVal = rewriter.create<arith::ConstantIntOp>(
408  loc, min, intTy.getIntOrFloatBitWidth());
409  auto maxVal = rewriter.create<arith::ConstantIntOp>(
410  loc, max, intTy.getIntOrFloatBitWidth());
411  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
412  }
413 
414  // tosa::SigmoidOp
  // sigmoid(x) = 1 / (1 + exp(-x))
415  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
416  auto one =
417  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
418  auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
419  auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
420  auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
421  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
422  }
423 
424  // tosa::CastOp
425  if (isa<tosa::CastOp>(op)) {
426  Type srcTy = elementTy;
427  Type dstTy = resultTypes.front();
  // NOTE(review): the initializer of bitExtend is not visible in this view.
  // From its uses below, true selects extension and false selects truncation.
428  bool bitExtend =
430 
431  if (srcTy == dstTy)
432  return args.front();
433 
434  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
435  return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
436  std::nullopt);
437 
438  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
439  return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
440  std::nullopt);
441 
442  // 1-bit integers need to be treated as signless.
443  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
444  return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
445  std::nullopt);
446 
447  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
448  return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
449  std::nullopt);
450 
451  // Unsigned integers need an unrealized cast so that they can be passed
452  // to UIToFP.
453  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
454  auto unrealizedCast =
455  rewriter
456  .create<UnrealizedConversionCastOp>(
457  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
458  args[0])
459  .getResult(0);
460  return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
461  unrealizedCast);
462  }
463 
464  // All other si-to-fp conversions should be handled by SIToFP.
465  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
466  return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
467  std::nullopt);
468 
469  // Casting to boolean, floats need to only be checked as not-equal to zero.
470  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
471  Value zero = rewriter.create<arith::ConstantOp>(
472  loc, rewriter.getFloatAttr(srcTy, 0.0));
473  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
474  args.front(), zero);
475  }
476 
  // Float-to-signed-int: round to nearest even, then saturate to the integer
  // range using one of three strategies depending on the float format.
477  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
478  auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
479 
480  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
481  // Check whether neither int min nor int max can be represented in the
482  // input floating-point type due to too short exponent range.
483  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
484  APFloat::semanticsMaxExponent(fltSemantics)) {
485  // Use cmp + select to replace infinites by int min / int max. Other
486  // integral values can be represented in the integer space.
487  auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
488  auto posInf = rewriter.create<arith::ConstantOp>(
489  loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
490  APFloat::getInf(fltSemantics)));
491  auto negInf = rewriter.create<arith::ConstantOp>(
492  loc, rewriter.getFloatAttr(
493  getElementTypeOrSelf(srcTy),
494  APFloat::getInf(fltSemantics, /*Negative=*/true)));
495  auto overflow = rewriter.create<arith::CmpFOp>(
496  loc, arith::CmpFPredicate::UEQ, rounded, posInf);
497  auto underflow = rewriter.create<arith::CmpFOp>(
498  loc, arith::CmpFPredicate::UEQ, rounded, negInf);
499  auto intMin = rewriter.create<arith::ConstantOp>(
500  loc, rewriter.getIntegerAttr(
501  getElementTypeOrSelf(dstTy),
502  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
503  auto intMax = rewriter.create<arith::ConstantOp>(
504  loc, rewriter.getIntegerAttr(
505  getElementTypeOrSelf(dstTy),
506  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
507  auto maxClamped =
508  rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
509  return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
510  maxClamped);
511  }
512 
513  auto intMinFP = rewriter.create<arith::ConstantOp>(
514  loc, rewriter.getFloatAttr(
515  getElementTypeOrSelf(srcTy),
516  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
517  .getSExtValue()));
518 
519  // Check whether the mantissa has enough bits to represent int max.
520  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
521  dstTy.getIntOrFloatBitWidth() - 1) {
522  // Int min can also be represented since it is a power of two and thus
523  // consists of a single leading bit. Therefore we can clamp the input
524  // in the floating-point domain.
525 
526  auto intMaxFP = rewriter.create<arith::ConstantOp>(
527  loc, rewriter.getFloatAttr(
528  getElementTypeOrSelf(srcTy),
529  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
530  .getSExtValue()));
531 
532  Value clamped =
533  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
534  return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
535  }
536 
537  // Due to earlier check we know exponent range is big enough to represent
538  // int min. We can therefore rely on int max + 1 being representable as
539  // well because it's just int min with a positive sign. So clamp the min
540  // value and compare against that to select the max int value if needed.
541  auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
542  loc, rewriter.getFloatAttr(
543  getElementTypeOrSelf(srcTy),
544  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
545  .getSExtValue() +
546  1));
547 
548  auto intMax = rewriter.create<arith::ConstantOp>(
549  loc, rewriter.getIntegerAttr(
550  getElementTypeOrSelf(dstTy),
551  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
552  auto minClampedFP =
553  rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
554  auto minClamped =
555  rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
556  auto overflow = rewriter.create<arith::CmpFOp>(
557  loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
558  return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
559  minClamped);
560  }
561 
562  // Casting to boolean, integers need to only be checked as not-equal to
563  // zero.
564  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
565  Value zero = rewriter.create<arith::ConstantIntOp>(
566  loc, 0, srcTy.getIntOrFloatBitWidth());
567  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
568  args.front(), zero);
569  }
570 
571  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
572  return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
573  std::nullopt);
574 
575  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
576  return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
577  }
578  }
579 
  // No lowering matched this op / element-type combination.
580  (void)rewriter.notifyMatchFailure(
581  op, "unhandled op for linalg body calculation for elementwise op");
582  return nullptr;
583 }
584 
585 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
586  int64_t rank) {
587  // No need to expand if we are already at the desired rank
588  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
589  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
590  int64_t numExtraDims = rank - shapedType.getRank();
591  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
592  if (!numExtraDims)
593  return tensor;
594 
595  // Compute reassociation indices
596  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
597  shapedType.getRank());
598  int64_t index = 0;
599  for (index = 0; index <= numExtraDims; index++)
600  reassociationIndices[0].push_back(index);
601  for (size_t position = 1; position < reassociationIndices.size(); position++)
602  reassociationIndices[position].push_back(index++);
603 
604  // Compute result type
605  SmallVector<int64_t> resultShape;
606  for (index = 0; index < numExtraDims; index++)
607  resultShape.push_back(1);
608  for (auto size : shapedType.getShape())
609  resultShape.push_back(size);
610  auto resultType =
611  RankedTensorType::get(resultShape, shapedType.getElementType());
612 
613  // Emit 'tensor.expand_shape' op
614  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
615  reassociationIndices);
616 }
617 
619  Location loc, Operation *operation) {
  // Expand every operand of `operation` to the rank of its first result type
  // (which must be a RankedTensorType) by prepending unit dimensions via
  // expandRank. Operands already at that rank are returned unchanged.
620  auto rank =
621  cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
622  return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
623  return expandRank(rewriter, loc, operand, rank);
624  });
625 }
626 
628 
629 // Emit an 'arith.constant' op for the given index if it has not been created
630 // yet, or return an existing constant. This prevents the creation of
631 // redundant constants, easing readability of emitted code for unit tests.
// A cached constant is only usable where it dominates its uses, so a pool
// should not be shared across regions (a fresh local pool is used there).
633  IndexPool &indexPool, int64_t index) {
634  auto [it, inserted] = indexPool.try_emplace(index);
635  if (inserted)
636  it->second =
637  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
638  return it->second;
639 }
640 
642  IndexPool &indexPool, Value tensor, int64_t index) {
  // Emit a 'tensor.dim' op querying dimension `index` of `tensor`. The
  // dimension operand is taken from (or added to) `indexPool`.
643  auto indexValue = createIndex(rewriter, loc, indexPool, index);
644  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
645 }
646 
648  IndexPool &indexPool, Value tensor,
649  int64_t index) {
  // Return the size of dimension `index` of `tensor`. A static dimension is
  // returned as an index attribute (no IR emitted). A dynamic dimension is
  // returned as a runtime 'tensor.dim' value.
650  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
651  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
652  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
653  if (shapedType.isDynamicDim(index))
654  return getTensorDim(rewriter, loc, indexPool, tensor, index);
655  return rewriter.getIndexAttr(shapedType.getDimSize(index));
656 }
657 
658 static bool operandsAndResultsRanked(Operation *operation) {
659  auto isRanked = [](Value value) {
660  return isa<RankedTensorType>(value.getType());
661  };
662  return llvm::all_of(operation->getOperands(), isRanked) &&
663  llvm::all_of(operation->getResults(), isRanked);
664 }
665 
666 // Compute the runtime dimension size for dimension 'dim' of the output by
667 // inspecting input 'operands', all of which are expected to have the same rank.
668 // This function returns a pair {targetSize, masterOperand}.
669 //
670 // The runtime size of the output dimension is returned either as a statically
671 // computed attribute or as a runtime SSA value.
672 //
673 // If the target size was inferred directly from one dominating operand, that
674 // operand is returned in 'masterOperand'. If the target size is inferred from
675 // multiple operands, 'masterOperand' is set to nullptr.
676 static std::pair<OpFoldResult, Value>
678  ValueRange operands, int64_t dim) {
679  // If any input operand contains a static size greater than 1 for this
680  // dimension, that is the target size. An occurrence of an additional static
681  // dimension greater than 1 with a different value is undefined behavior.
682  for (auto operand : operands) {
683  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
684  if (!ShapedType::isDynamic(size) && size > 1)
685  return {rewriter.getIndexAttr(size), operand};
686  }
687 
688  // Filter operands with dynamic dimension
689  auto operandsWithDynamicDim =
690  llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
691  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
692  }));
693 
694  // If no operand has a dynamic dimension, it means all sizes were 1
695  if (operandsWithDynamicDim.empty())
696  return {rewriter.getIndexAttr(1), operands.front()};
697 
698  // Emit code that computes the runtime size for this dimension. If there is
699  // only one operand with a dynamic dimension, it is considered the master
700  // operand that determines the runtime size of the output dimension.
701  auto targetSize =
702  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
703  if (operandsWithDynamicDim.size() == 1)
704  return {targetSize, operandsWithDynamicDim[0]};
705 
706  // Calculate maximum size among all dynamic dimensions
  // (unsigned max over the runtime sizes; any remaining static sizes are 1).
707  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
708  auto nextSize =
709  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
710  targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
711  }
712  return {targetSize, nullptr};
713 }
714 
715 // Compute the runtime output size for all dimensions. This function returns
716 // a pair {targetShape, masterOperands}.
// Each entry is computeTargetSize for the corresponding dimension. The rank is
// taken from the first operand, and `operands` must be non-empty.
717 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
719  IndexPool &indexPool, ValueRange operands) {
720  assert(!operands.empty());
721  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
722  SmallVector<OpFoldResult> targetShape;
723  SmallVector<Value> masterOperands;
724  for (auto dim : llvm::seq<int64_t>(0, rank)) {
725  auto [targetSize, masterOperand] =
726  computeTargetSize(rewriter, loc, indexPool, operands, dim);
727  targetShape.push_back(targetSize);
728  masterOperands.push_back(masterOperand);
729  }
730  return {targetShape, masterOperands};
731 }
732 
734  IndexPool &indexPool, Value operand,
735  int64_t dim, OpFoldResult targetSize,
736  Value masterOperand) {
  // Broadcast dimension `dim` of `operand` to `targetSize` at runtime when that
  // dimension is dynamic. The emitted scf.if checks whether the runtime size is
  // 1. If so, a linalg.generic replicates element 0 along `dim` into a
  // tensor.empty of the target size. Otherwise the operand is yielded as is.
  // Returns the scf.if result, or `operand` itself when no check is needed.
737  // Nothing to do if this is a static dimension
738  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
739  if (!rankedTensorType.isDynamicDim(dim))
740  return operand;
741 
742  // If the target size for this dimension was directly inferred by only taking
743  // this operand into account, there is no need to broadcast. This is an
744  // optimization that will prevent redundant control flow, and constitutes the
745  // main motivation for tracking "master operands".
746  if (operand == masterOperand)
747  return operand;
748 
749  // Affine maps for 'linalg.generic' op
750  auto rank = rankedTensorType.getRank();
751  SmallVector<AffineExpr> affineExprs;
752  for (auto index : llvm::seq<int64_t>(0, rank)) {
753  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
754  : rewriter.getAffineDimExpr(index);
755  affineExprs.push_back(affineExpr);
756  }
757  auto broadcastAffineMap =
758  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
759  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
760  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
761 
762  // Check if broadcast is necessary
763  auto one = createIndex(rewriter, loc, indexPool, 1);
764  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
765  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
766  loc, arith::CmpIPredicate::eq, runtimeSize, one);
767 
768  // Emit 'then' region of 'scf.if'
769  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
770  // It is not safe to cache constants across regions.
771  // New constants could potentially violate dominance requirements.
772  IndexPool localPool;
773 
774  // Emit 'tensor.empty' op
  // NOTE(review): `rewriter` (not `opBuilder`) is used for the dim queries
  // and the tensor.cast below. This relies on the region builder sharing
  // rewriter's insertion point. Confirm against scf::IfOp's builder.
775  SmallVector<OpFoldResult> outputTensorShape;
776  for (auto index : llvm::seq<int64_t>(0, rank)) {
777  auto size = index == dim ? targetSize
778  : getOrFoldTensorDim(rewriter, loc, localPool,
779  operand, index);
780  outputTensorShape.push_back(size);
781  }
782  Value outputTensor = opBuilder.create<tensor::EmptyOp>(
783  loc, outputTensorShape, rankedTensorType.getElementType());
784 
785  // Emit 'linalg.generic' op
786  auto resultTensor =
787  opBuilder
788  .create<linalg::GenericOp>(
789  loc, outputTensor.getType(), operand, outputTensor, affineMaps,
791  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
792  // Emit 'linalg.yield' op
793  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
794  })
795  .getResult(0);
796 
797  // Cast to original operand type if necessary
798  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
799  loc, operand.getType(), resultTensor);
800 
801  // Emit 'scf.yield' op
802  opBuilder.create<scf::YieldOp>(loc, castResultTensor);
803  };
804 
805  // Emit 'else' region of 'scf.if'
806  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
807  opBuilder.create<scf::YieldOp>(loc, operand);
808  };
809 
810  // Emit 'scf.if' op
811  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
812  emitThenRegion, emitElseRegion);
813  return ifOp.getResult(0);
814 }
815 
817  IndexPool &indexPool, Value operand,
818  ArrayRef<OpFoldResult> targetShape,
819  ArrayRef<Value> masterOperands) {
820  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
821  assert((int64_t)targetShape.size() == rank);
822  assert((int64_t)masterOperands.size() == rank);
823  for (auto index : llvm::seq<int64_t>(0, rank))
824  operand =
825  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
826  targetShape[index], masterOperands[index]);
827  return operand;
828 }
829 
830 static SmallVector<Value>
832  IndexPool &indexPool, ValueRange operands,
833  ArrayRef<OpFoldResult> targetShape,
834  ArrayRef<Value> masterOperands) {
835  // No need to broadcast for unary operations
836  if (operands.size() == 1)
837  return operands;
838 
839  // Broadcast dynamic dimensions operand by operand
840  return llvm::map_to_vector(operands, [&](Value operand) {
841  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
842  targetShape, masterOperands);
843  });
844 }
845 
846 static LogicalResult
848  Operation *operation, ValueRange operands,
849  ArrayRef<OpFoldResult> targetShape) {
850  // Generate output tensor
851  auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
852  Value outputTensor = rewriter.create<tensor::EmptyOp>(
853  loc, targetShape, resultType.getElementType());
854 
855  // Create affine maps. Input affine maps broadcast static dimensions of size
856  // 1. The output affine map is an identity map.
857  //
858  auto rank = resultType.getRank();
859  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
860  auto shape = cast<ShapedType>(operand.getType()).getShape();
861  SmallVector<AffineExpr> affineExprs;
862  for (auto it : llvm::enumerate(shape)) {
863  auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
864  : rewriter.getAffineDimExpr(it.index());
865  affineExprs.push_back(affineExpr);
866  }
867  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
868  });
869  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
870 
871  // Emit 'linalg.generic' op
872  bool encounteredError = false;
873  auto linalgOp = rewriter.create<linalg::GenericOp>(
874  loc, outputTensor.getType(), operands, outputTensor, affineMaps,
876  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
878  operation, blockArgs.take_front(operation->getNumOperands()),
879  {resultType.getElementType()}, rewriter);
880  if (!opResult) {
881  encounteredError = true;
882  return;
883  }
884  opBuilder.create<linalg::YieldOp>(loc, opResult);
885  });
886  if (encounteredError)
887  return rewriter.notifyMatchFailure(
888  operation, "unable to create linalg.generic body for elementwise op");
889 
890  // Cast 'linalg.generic' result into original result type if needed
891  auto castResult = rewriter.createOrFold<tensor::CastOp>(
892  loc, resultType, linalgOp->getResult(0));
893  rewriter.replaceOp(operation, castResult);
894  return success();
895 }
896 
897 static LogicalResult
899  PatternRewriter &rewriter) {
900 
901  // Collect op properties
902  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
903  assert(operation->getNumOperands() >= 1 &&
904  "elementwise op expects at least 1 operand");
905  if (!operandsAndResultsRanked(operation))
906  return rewriter.notifyMatchFailure(operation,
907  "Unranked tensors not supported");
908 
909  // Lower operation
910  IndexPool indexPool;
911  auto loc = operation->getLoc();
912  auto expandedOperands = expandInputRanks(rewriter, loc, operation);
913  auto [targetShape, masterOperands] =
914  computeTargetShape(rewriter, loc, indexPool, expandedOperands);
915  auto broadcastOperands = broadcastDynamicDimensions(
916  rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
917  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
918  targetShape);
919 }
920 
921 // Returns the constant initial value for a given reduction operation. The
922 // attribute type varies depending on the element type required.
923 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
924  PatternRewriter &rewriter) {
925  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
926  return rewriter.getFloatAttr(elementTy, 0.0);
927 
928  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
929  return rewriter.getIntegerAttr(elementTy, 0);
930 
931  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
932  return rewriter.getFloatAttr(elementTy, 1.0);
933 
934  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
935  return rewriter.getIntegerAttr(elementTy, 1);
936 
937  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
938  return rewriter.getFloatAttr(
939  elementTy, APFloat::getLargest(
940  cast<FloatType>(elementTy).getFloatSemantics(), false));
941 
942  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
943  return rewriter.getIntegerAttr(
944  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
945 
946  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
947  return rewriter.getFloatAttr(
948  elementTy, APFloat::getLargest(
949  cast<FloatType>(elementTy).getFloatSemantics(), true));
950 
951  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
952  return rewriter.getIntegerAttr(
953  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
954 
955  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
956  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
957 
958  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
959  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
960 
961  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
962  return rewriter.getFloatAttr(
963  elementTy, APFloat::getLargest(
964  cast<FloatType>(elementTy).getFloatSemantics(), true));
965 
966  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
967  return rewriter.getIntegerAttr(
968  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
969 
970  return {};
971 }
972 
973 // Creates the body calculation for a reduction. The operations vary depending
974 // on the input type.
976  ValueRange args,
977  Type elementTy,
978  PatternRewriter &rewriter) {
979  Location loc = op->getLoc();
980  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
981  return rewriter.create<arith::AddFOp>(loc, args);
982  }
983 
984  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
985  return rewriter.create<arith::AddIOp>(loc, args);
986  }
987 
988  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
989  return rewriter.create<arith::MulFOp>(loc, args);
990  }
991 
992  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
993  return rewriter.create<arith::MulIOp>(loc, args);
994  }
995 
996  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
997  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
998  }
999 
1000  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1001  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1002  }
1003 
1004  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1005  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1006  }
1007 
1008  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1009  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1010  }
1011 
1012  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1013  return rewriter.create<arith::AndIOp>(loc, args);
1014 
1015  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1016  return rewriter.create<arith::OrIOp>(loc, args);
1017 
1018  return {};
1019 }
1020 
1021 // Performs the match and rewrite for reduction operations. This includes
1022 // declaring a correctly sized initial value, and the linalg.generic operation
1023 // that reduces across the specified axis.
1025  PatternRewriter &rewriter) {
1026  auto loc = op->getLoc();
1027  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
1028  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
1029  auto elementTy = resultTy.getElementType();
1030  Value input = op->getOperand(0);
1031 
1032  SmallVector<int64_t> reduceShape;
1033  SmallVector<Value> dynDims;
1034  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1035  if (axis != i) {
1036  reduceShape.push_back(inputTy.getDimSize(i));
1037  if (inputTy.isDynamicDim(i))
1038  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1039  }
1040  }
1041 
1042  // First fill the output buffer with the init value.
1043  auto emptyTensor =
1044  rewriter
1045  .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1046  dynDims)
1047  .getResult();
1048 
1049  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1050  if (!fillValueAttr)
1051  return rewriter.notifyMatchFailure(
1052  op, "No initial value found for reduction operation");
1053 
1054  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1055  auto filledTensor = rewriter
1056  .create<linalg::FillOp>(loc, ValueRange{fillValue},
1057  ValueRange{emptyTensor})
1058  .result();
1059 
1060  bool didEncounterError = false;
1061  auto linalgOp = rewriter.create<linalg::ReduceOp>(
1062  loc, input, filledTensor, axis,
1063  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1065  op, blockArgs, elementTy, rewriter);
1066  if (result)
1067  didEncounterError = true;
1068 
1069  nestedBuilder.create<linalg::YieldOp>(loc, result);
1070  });
1071 
1072  if (!didEncounterError)
1073  return rewriter.notifyMatchFailure(
1074  op, "unable to create linalg.generic body for reduce op");
1075 
1076  SmallVector<ReassociationExprs, 4> reassociationMap;
1077  uint64_t expandInputRank =
1078  cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
1079  reassociationMap.resize(expandInputRank);
1080 
1081  for (uint64_t i = 0; i < expandInputRank; i++) {
1082  int32_t dimToPush = i > axis ? i + 1 : i;
1083  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1084  }
1085 
1086  if (expandInputRank != 0) {
1087  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1088  reassociationMap[expandedDim].push_back(
1089  rewriter.getAffineDimExpr(expandedDim + 1));
1090  }
1091 
1092  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1093  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1094  // not have access to such information. This matters when handling dynamically
1095  // sized tensors.
1096  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1097  op, resultTy, linalgOp.getResults()[0], reassociationMap);
1098  return success();
1099 }
1100 
1101 namespace {
1102 
1103 template <typename SrcOp>
1104 class PointwiseConverter : public OpRewritePattern<SrcOp> {
1105 public:
1107 
1108  LogicalResult matchAndRewrite(SrcOp op,
1109  PatternRewriter &rewriter) const final {
1110  return elementwiseMatchAndRewriteHelper(op, rewriter);
1111  }
1112 };
1113 
1114 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1115 public:
1117 
1118  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1119  PatternRewriter &rewriter) const final {
1120  auto loc = op.getLoc();
1121  auto input = op.getInput();
1122  auto inputTy = cast<ShapedType>(op.getInput().getType());
1123  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1124  unsigned rank = inputTy.getRank();
1125 
1126  // This is an illegal configuration. terminate and log an error
1127  if (op.getDoubleRound() && !op.getScale32())
1128  return rewriter.notifyMatchFailure(
1129  op, "tosa.rescale requires scale32 for double_round to be true");
1130 
1131  SmallVector<Value> dynDims;
1132  for (int i = 0; i < outputTy.getRank(); i++) {
1133  if (outputTy.isDynamicDim(i)) {
1134  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1135  }
1136  }
1137 
1138  // The shift and multiplier values.
1139  SmallVector<int32_t> multiplierValues(op.getMultiplier());
1140  SmallVector<int8_t> shiftValues(op.getShift());
1141 
1142  // If we shift by more than the bitwidth, this just sets to 0.
1143  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1144  if (shiftValues[i] > 63) {
1145  shiftValues[i] = 0;
1146  multiplierValues[i] = 0;
1147  }
1148  }
1149 
1150  // Double round only occurs if shift is greater than 31, check that this
1151  // is ever true.
1152  bool doubleRound =
1153  op.getDoubleRound() &&
1154  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1155 
1156  SmallVector<AffineMap> indexingMaps = {
1157  rewriter.getMultiDimIdentityMap(rank)};
1158  SmallVector<Value, 4> genericInputs = {input};
1159 
1160  // If we are rescaling per-channel then we need to store the multiplier
1161  // values in a buffer.
1162  Value multiplierConstant;
1163  int64_t multiplierArg = 0;
1164  if (multiplierValues.size() == 1) {
1165  multiplierConstant = rewriter.create<arith::ConstantOp>(
1166  loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1167  } else {
1168  SmallVector<AffineExpr, 2> multiplierExprs{
1169  rewriter.getAffineDimExpr(rank - 1)};
1170  auto multiplierType =
1171  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1172  rewriter.getI32Type());
1173  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1174  loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1175 
1176  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1177  /*symbolCount=*/0, multiplierExprs,
1178  rewriter.getContext()));
1179 
1180  multiplierArg = indexingMaps.size() - 1;
1181  }
1182 
1183  // If we are rescaling per-channel then we need to store the shift
1184  // values in a buffer.
1185  Value shiftConstant;
1186  int64_t shiftArg = 0;
1187  if (shiftValues.size() == 1) {
1188  shiftConstant = rewriter.create<arith::ConstantOp>(
1189  loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1190  } else {
1191  SmallVector<AffineExpr, 2> shiftExprs = {
1192  rewriter.getAffineDimExpr(rank - 1)};
1193  auto shiftType =
1194  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1195  rewriter.getIntegerType(8));
1196  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1197  loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1198  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1199  /*symbolCount=*/0, shiftExprs,
1200  rewriter.getContext()));
1201  shiftArg = indexingMaps.size() - 1;
1202  }
1203 
1204  // Indexing maps for output values.
1205  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1206 
1207  // Construct the indexing maps needed for linalg.generic ops.
1208  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1209  loc, outputTy.getShape(), outputTy.getElementType(),
1210  ArrayRef<Value>({dynDims}));
1211 
1212  auto linalgOp = rewriter.create<linalg::GenericOp>(
1213  loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1214  getNParallelLoopsAttrs(rank),
1215  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1216  ValueRange blockArgs) {
1217  Value value = blockArgs[0];
1218  Type valueTy = value.getType();
1219 
1220  // For now we do all of our math in 64-bit. This is not optimal but
1221  // should be correct for now, consider computing correct bit depth
1222  // later.
1223  int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1224 
1225  auto inputZp = createConstFromIntAttribute<int32_t>(
1226  op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1227  nestedBuilder);
1228  auto outputZp = createConstFromIntAttribute<int32_t>(
1229  op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1230 
1231  Value multiplier = multiplierConstant ? multiplierConstant
1232  : blockArgs[multiplierArg];
1233  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1234 
1235  if (valueTy.getIntOrFloatBitWidth() < 32) {
1236  if (valueTy.isUnsignedInteger()) {
1237  value = nestedBuilder
1238  .create<UnrealizedConversionCastOp>(
1239  nestedLoc,
1240  nestedBuilder.getIntegerType(
1241  valueTy.getIntOrFloatBitWidth()),
1242  value)
1243  .getResult(0);
1244  value = nestedBuilder.create<arith::ExtUIOp>(
1245  nestedLoc, nestedBuilder.getI32Type(), value);
1246  } else {
1247  value = nestedBuilder.create<arith::ExtSIOp>(
1248  nestedLoc, nestedBuilder.getI32Type(), value);
1249  }
1250  }
1251 
1252  value =
1253  nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1254 
1255  value = nestedBuilder.create<tosa::ApplyScaleOp>(
1256  loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1257  nestedBuilder.getBoolAttr(doubleRound));
1258 
1259  // Move to the new zero-point.
1260  value =
1261  nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1262 
1263  // Saturate to the output size.
1264  IntegerType outIntType =
1265  cast<IntegerType>(blockArgs.back().getType());
1266  unsigned outBitWidth = outIntType.getWidth();
1267 
1268  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1269  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1270 
1271  // Unsigned integers have a difference output value.
1272  if (outIntType.isUnsignedInteger()) {
1273  intMin = 0;
1274  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1275  }
1276 
1277  auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1278  loc, nestedBuilder.getI32IntegerAttr(intMin));
1279  auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1280  loc, nestedBuilder.getI32IntegerAttr(intMax));
1281 
1282  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1283  nestedBuilder);
1284 
1285  if (outIntType.getWidth() < 32) {
1286  value = nestedBuilder.create<arith::TruncIOp>(
1287  nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1288  value);
1289 
1290  if (outIntType.isUnsignedInteger()) {
1291  value = nestedBuilder
1292  .create<UnrealizedConversionCastOp>(nestedLoc,
1293  outIntType, value)
1294  .getResult(0);
1295  }
1296  }
1297 
1298  nestedBuilder.create<linalg::YieldOp>(loc, value);
1299  });
1300 
1301  rewriter.replaceOp(op, linalgOp->getResults());
1302  return success();
1303  }
1304 };
1305 
1306 // Handle the resize case where the input is a 1x1 image. This case
1307 // can entirely avoiding having extract operations which target much
1308 // more difficult to optimize away.
1309 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1310 public:
1312 
1313  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1314  PatternRewriter &rewriter) const final {
1315  Location loc = op.getLoc();
1316  ImplicitLocOpBuilder builder(loc, rewriter);
1317  auto input = op.getInput();
1318  auto inputTy = cast<RankedTensorType>(input.getType());
1319  auto resultTy = cast<RankedTensorType>(op.getType());
1320  const bool isBilinear = op.getMode() == "BILINEAR";
1321 
1322  auto inputH = inputTy.getDimSize(1);
1323  auto inputW = inputTy.getDimSize(2);
1324  auto outputH = resultTy.getDimSize(1);
1325  auto outputW = resultTy.getDimSize(2);
1326 
1327  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1328  return rewriter.notifyMatchFailure(
1329  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1330 
1331  // TODO(suderman): These string values should be declared the TOSA dialect.
1332  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1333  return rewriter.notifyMatchFailure(
1334  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1335 
1336  if (inputTy == resultTy) {
1337  rewriter.replaceOp(op, input);
1338  return success();
1339  }
1340 
1341  ArrayRef<int64_t> scale = op.getScale();
1342 
1343  // Collapse the unit width and height away.
1344  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1345  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1346  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1347  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1348  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1349 
1350  auto collapseTy =
1351  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1352  inputTy.getElementType());
1353  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1354  reassociationMap);
1355 
1356  // Get any dynamic shapes that appear in the input format.
1357  llvm::SmallVector<Value> outputDynSize;
1358  if (inputTy.isDynamicDim(0))
1359  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1360  if (inputTy.isDynamicDim(3))
1361  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1362 
1363  // Generate the elementwise operation for casting scaling the input value.
1364  auto genericTy = collapseTy.clone(resultTy.getElementType());
1365  Value empty = builder.create<tensor::EmptyOp>(
1366  genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1367  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1368  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1369  utils::IteratorType::parallel);
1370 
1371  auto generic = builder.create<linalg::GenericOp>(
1372  genericTy, ValueRange{collapse}, ValueRange{empty},
1373  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1374  [=](OpBuilder &b, Location loc, ValueRange args) {
1375  Value value = args[0];
1376  // This is the quantized case.
1377  if (inputTy.getElementType() != resultTy.getElementType()) {
1378  value =
1379  b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1380 
1381  if (isBilinear && scale[0] != 0) {
1382  Value scaleY = b.create<arith::ConstantOp>(
1383  loc, b.getI32IntegerAttr(scale[0]));
1384  value = b.create<arith::MulIOp>(loc, value, scaleY);
1385  }
1386 
1387  if (isBilinear && scale[2] != 0) {
1388  Value scaleX = b.create<arith::ConstantOp>(
1389  loc, b.getI32IntegerAttr(scale[2]));
1390  value = b.create<arith::MulIOp>(loc, value, scaleX);
1391  }
1392  }
1393 
1394  b.create<linalg::YieldOp>(loc, value);
1395  });
1396 
1397  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1398  op, resultTy, generic.getResults()[0], reassociationMap);
1399  return success();
1400  }
1401 };
1402 
1403 // TOSA resize with width or height of 1 may be broadcasted to a wider
1404 // dimension. This is done by materializing a new tosa.resize without
1405 // the broadcasting behavior, and an explicit broadcast afterwards.
1406 class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1407 public:
1409 
1410  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1411  PatternRewriter &rewriter) const final {
1412  Location loc = op.getLoc();
1413  ImplicitLocOpBuilder builder(loc, rewriter);
1414  auto input = op.getInput();
1415  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1416  auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1417 
1418  if (!inputTy || !resultTy)
1419  return rewriter.notifyMatchFailure(op,
1420  "requires ranked input/output types");
1421 
1422  auto batch = inputTy.getDimSize(0);
1423  auto channels = inputTy.getDimSize(3);
1424  auto inputH = inputTy.getDimSize(1);
1425  auto inputW = inputTy.getDimSize(2);
1426  auto outputH = resultTy.getDimSize(1);
1427  auto outputW = resultTy.getDimSize(2);
1428 
1429  if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1430  return rewriter.notifyMatchFailure(
1431  op, "tosa.resize has no broadcasting behavior");
1432 
1433  // For any dimension that is broadcastable we generate a width of 1
1434  // on the output.
1435  llvm::SmallVector<int64_t> resizeShape;
1436  resizeShape.push_back(batch);
1437  resizeShape.push_back(inputH == 1 ? 1 : outputH);
1438  resizeShape.push_back(inputW == 1 ? 1 : outputW);
1439  resizeShape.push_back(channels);
1440 
1441  auto resizeTy = resultTy.clone(resizeShape);
1442  auto resize =
1443  builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1444 
1445  // Collapse an unit result dims.
1446  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1447  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1448  reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1449  if (inputH != 1)
1450  reassociationMap.push_back({});
1451  reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1452  if (inputW != 1)
1453  reassociationMap.push_back({});
1454  reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1455 
1456  llvm::SmallVector<int64_t> collapseShape{batch};
1457  if (inputH != 1)
1458  collapseShape.push_back(outputH);
1459  if (inputW != 1)
1460  collapseShape.push_back(outputW);
1461  collapseShape.push_back(channels);
1462 
1463  auto collapseTy = resultTy.clone(collapseShape);
1464  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1465  reassociationMap);
1466 
1467  // Broadcast the collapsed shape to the output result.
1468  llvm::SmallVector<Value> outputDynSize;
1469  if (inputTy.isDynamicDim(0))
1470  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1471  if (inputTy.isDynamicDim(3))
1472  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1473 
1474  SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1475  utils::IteratorType::parallel);
1476  Value empty = builder.create<tensor::EmptyOp>(
1477  resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1478 
1479  SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1480  if (inputH != 1)
1481  inputExprs.push_back(rewriter.getAffineDimExpr(1));
1482  if (inputW != 1)
1483  inputExprs.push_back(rewriter.getAffineDimExpr(2));
1484  inputExprs.push_back(rewriter.getAffineDimExpr(3));
1485 
1486  auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1487  inputExprs, rewriter.getContext());
1488 
1489  auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1490  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1491  op, resultTy, ValueRange{collapse}, ValueRange{empty},
1492  ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1493  [=](OpBuilder &b, Location loc, ValueRange args) {
1494  Value value = args[0];
1495  b.create<linalg::YieldOp>(loc, value);
1496  });
1497 
1498  return success();
1499  }
1500 };
1501 
1502 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1503 public:
1505 
1506  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1507  PatternRewriter &rewriter) const final {
1508  Location loc = op.getLoc();
1509  ImplicitLocOpBuilder b(loc, rewriter);
1510  auto input = op.getInput();
1511  auto inputTy = cast<ShapedType>(input.getType());
1512  auto resultTy = cast<ShapedType>(op.getType());
1513  auto resultETy = resultTy.getElementType();
1514 
1515  bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1516  auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1517 
1518  auto imageH = inputTy.getShape()[1];
1519  auto imageW = inputTy.getShape()[2];
1520 
1521  auto dynamicDimsOr =
1522  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1523  if (!dynamicDimsOr.has_value())
1524  return rewriter.notifyMatchFailure(
1525  op, "unable to get dynamic dimensions of tosa.resize");
1526 
1527  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1528  return rewriter.notifyMatchFailure(
1529  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1530 
1531  SmallVector<AffineMap, 2> affineMaps = {
1532  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1533  auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1534  *dynamicDimsOr);
1535  auto genericOp = b.create<linalg::GenericOp>(
1536  resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1537  getNParallelLoopsAttrs(resultTy.getRank()));
1538  Value resize = genericOp.getResult(0);
1539 
1540  {
1541  OpBuilder::InsertionGuard regionGuard(b);
1542  b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1543  TypeRange({resultETy}), loc);
1544  Value batch = b.create<linalg::IndexOp>(0);
1545  Value y = b.create<linalg::IndexOp>(1);
1546  Value x = b.create<linalg::IndexOp>(2);
1547  Value channel = b.create<linalg::IndexOp>(3);
1548 
1549  Value zeroI32 =
1550  b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1551  Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1552  Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1553  Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1554 
1555  Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1556  Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1557 
1558  ArrayRef<int64_t> offset = op.getOffset();
1559  ArrayRef<int64_t> border = op.getBorder();
1560  ArrayRef<int64_t> scale = op.getScale();
1561 
1562  Value yScaleN, yScaleD, xScaleN, xScaleD;
1563  yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1564  yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1565  xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1566  xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1567 
1568  Value yOffset, xOffset, yBorder, xBorder;
1569  yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1570  xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1571  yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1572  xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1573 
1574  // Compute the ix and dx values for both the X and Y dimensions.
1575  auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1576  Value scaleN, Value scaleD, Value offset,
1577  int size, ImplicitLocOpBuilder &b) {
1578  if (size == 1) {
1579  index = zeroI32;
1580  delta = zeroFp;
1581  return;
1582  }
1583  // x = x * scale_d + offset;
1584  // ix = floor(x / scale_n)
1585  Value val = b.create<arith::MulIOp>(in, scaleD);
1586  val = b.create<arith::AddIOp>(val, offset);
1587  index = b.create<arith::FloorDivSIOp>(val, scaleN);
1588 
1589  // rx = x % scale_n
1590  // dx = rx / scale_n
1591  Value r = b.create<arith::RemSIOp>(val, scaleN);
1592  Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1593  Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1594  delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1595  };
1596 
1597  // Compute the ix and dx values for the X and Y dimensions - int case.
1598  auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1599  Value scaleN, Value scaleD, Value offset,
1600  int size, ImplicitLocOpBuilder &b) {
1601  if (size == 1) {
1602  index = zeroI32;
1603  delta = zeroI32;
1604  return;
1605  }
1606  // x = x * scale_d + offset;
1607  // ix = floor(x / scale_n)
1608  // dx = x - ix * scale_n;
1609  Value val = b.create<arith::MulIOp>(in, scaleD);
1610  val = b.create<arith::AddIOp>(val, offset);
1611  index = b.create<arith::DivSIOp>(val, scaleN);
1612  delta = b.create<arith::MulIOp>(index, scaleN);
1613  delta = b.create<arith::SubIOp>(val, delta);
1614  };
1615 
1616  Value ix, iy, dx, dy;
1617  if (floatingPointMode) {
1618  getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1619  getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1620  } else {
1621  getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1622  getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1623  }
1624 
1625  if (op.getMode() == "NEAREST_NEIGHBOR") {
1626  auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1627 
1628  auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1629  Value max, int size,
1630  ImplicitLocOpBuilder &b) -> Value {
1631  if (size == 1) {
1632  return b.create<arith::ConstantIndexOp>(0);
1633  }
1634 
1635  Value pred;
1636  if (floatingPointMode) {
1637  auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1638  pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1639  } else {
1640  Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1641  pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1642  dvalDouble, scale);
1643  }
1644 
1645  auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1646  val = b.create<arith::AddIOp>(val, offset);
1647  val = clampIntHelper(loc, val, zeroI32, max, b);
1648  return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1649  };
1650 
1651  iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1652  ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1653 
1654  Value result = b.create<tensor::ExtractOp>(
1655  input, ValueRange{batch, iy, ix, channel});
1656 
1657  b.create<linalg::YieldOp>(result);
1658  } else {
1659  // The mode here must be BILINEAR.
1660  assert(op.getMode() == "BILINEAR");
1661 
1662  auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1663 
1664  auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1666  val0 = in;
1667  val1 = b.create<arith::AddIOp>(val0, oneVal);
1668  val0 = clampIntHelper(loc, val0, zeroI32, max, b);
1669  val1 = clampIntHelper(loc, val1, zeroI32, max, b);
1670  val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1671  val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1672  };
1673 
1674  // Linalg equivalent to the section below:
1675  // int16_t iy0 = apply_max(iy, 0);
1676  // int16_t iy1 = apply_min(iy + 1, IH - 1);
1677  // int16_t ix0 = apply_max(ix, 0);
1678  // int16_t ix1 = apply_min(ix + 1, IW - 1);
1679  Value x0, x1, y0, y1;
1680  getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1681  getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1682 
1683  Value y0x0 = b.create<tensor::ExtractOp>(
1684  input, ValueRange{batch, y0, x0, channel});
1685  Value y0x1 = b.create<tensor::ExtractOp>(
1686  input, ValueRange{batch, y0, x1, channel});
1687  Value y1x0 = b.create<tensor::ExtractOp>(
1688  input, ValueRange{batch, y1, x0, channel});
1689  Value y1x1 = b.create<tensor::ExtractOp>(
1690  input, ValueRange{batch, y1, x1, channel});
1691 
1692  if (floatingPointMode) {
1693  auto oneVal =
1694  b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1695  auto interpolate = [&](Value val0, Value val1, Value delta,
1696  int inputSize,
1697  ImplicitLocOpBuilder &b) -> Value {
1698  if (inputSize == 1)
1699  return val0;
1700  Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1701  Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1702  Value mul1 = b.create<arith::MulFOp>(val1, delta);
1703  return b.create<arith::AddFOp>(mul0, mul1);
1704  };
1705 
1706  // Linalg equivalent to the section below:
1707  // topAcc = v00 * (unit_x - dx);
1708  // topAcc += v01 * dx;
1709  Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1710 
1711  // Linalg equivalent to the section below:
1712  // bottomAcc = v10 * (unit_x - dx);
1713  // bottomAcc += v11 * dx;
1714  Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1715 
1716  // Linalg equivalent to the section below:
1717  // result = topAcc * (unit_y - dy) + bottomAcc * dy
1718  Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1719  b.create<linalg::YieldOp>(result);
1720  } else {
1721  // Perform in quantized space.
1722  y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1723  y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1724  y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1725  y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1726 
1727  const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1728  if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1729  dx = b.create<arith::ExtSIOp>(resultETy, dx);
1730  dy = b.create<arith::ExtSIOp>(resultETy, dy);
1731  }
1732 
1733  Value yScaleNExt = yScaleN;
1734  Value xScaleNExt = xScaleN;
1735 
1736  const int64_t scaleBitwidth =
1737  xScaleN.getType().getIntOrFloatBitWidth();
1738  if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1739  yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1740  xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1741  }
1742 
1743  auto interpolate = [](Value val0, Value val1, Value weight1,
1744  Value scale, int inputSize,
1745  ImplicitLocOpBuilder &b) -> Value {
1746  if (inputSize == 1)
1747  return b.create<arith::MulIOp>(val0, scale);
1748  Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1749  Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1750  Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1751  return b.create<arith::AddIOp>(mul0, mul1);
1752  };
1753 
1754  Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1755  Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1756  Value result =
1757  interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1758  b.create<linalg::YieldOp>(result);
1759  }
1760  }
1761  }
1762 
1763  rewriter.replaceOp(op, resize);
1764  return success();
1765  }
1766 };
1767 
1768 // At the codegen level any identity operations should be removed. Any cases
1769 // where identity is load-bearing (e.g. cross device computation) should be
1770 // handled before lowering to codegen.
1771 template <typename SrcOp>
1772 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1773 public:
1775 
1776  LogicalResult matchAndRewrite(SrcOp op,
1777  PatternRewriter &rewriter) const final {
1778  rewriter.replaceOp(op, op.getOperation()->getOperands());
1779  return success();
1780  }
1781 };
1782 
1783 template <typename SrcOp>
1784 class ReduceConverter : public OpRewritePattern<SrcOp> {
1785 public:
1787 
1788  LogicalResult matchAndRewrite(SrcOp reduceOp,
1789  PatternRewriter &rewriter) const final {
1790  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1791  }
1792 };
1793 
1794 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1795 public:
1797 
1798  LogicalResult matchAndRewrite(tosa::ReverseOp op,
1799  PatternRewriter &rewriter) const final {
1800  auto loc = op.getLoc();
1801  Value input = op.getInput();
1802  auto inputTy = cast<ShapedType>(input.getType());
1803  auto resultTy = cast<ShapedType>(op.getType());
1804  auto axis = op.getAxis();
1805 
1806  SmallVector<Value> dynDims;
1807  for (int i = 0; i < inputTy.getRank(); i++) {
1808  if (inputTy.isDynamicDim(i)) {
1809  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1810  }
1811  }
1812 
1813  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1814 
1815  // First fill the output buffer with the init value.
1816  auto emptyTensor = rewriter
1817  .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1818  inputTy.getElementType(),
1819  ArrayRef<Value>({dynDims}))
1820  .getResult();
1821  SmallVector<AffineMap, 2> affineMaps = {
1822  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1823 
1824  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1825  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
1826  getNParallelLoopsAttrs(resultTy.getRank()),
1827  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1828  llvm::SmallVector<Value> indices;
1829  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1830  Value index =
1831  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1832  if (i == axis) {
1833  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1834  auto sizeMinusOne =
1835  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1836  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1837  index);
1838  }
1839 
1840  indices.push_back(index);
1841  }
1842 
1843  auto extract = nestedBuilder.create<tensor::ExtractOp>(
1844  nestedLoc, input, indices);
1845  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1846  extract.getResult());
1847  });
1848  return success();
1849  }
1850 };
1851 
1852 // This converter translate a tile operation to a reshape, broadcast, reshape.
1853 // The first reshape minimally expands each tiled dimension to include a
1854 // proceding size-1 dim. This dim is then broadcasted to the appropriate
1855 // multiple.
1856 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1858 
1860  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1861  ConversionPatternRewriter &rewriter) const override {
1862  auto loc = op.getLoc();
1863  auto input = op.getInput1();
1864  auto inputTy = cast<ShapedType>(input.getType());
1865  auto inputShape = inputTy.getShape();
1866  auto resultTy = cast<ShapedType>(op.getType());
1867  auto elementTy = inputTy.getElementType();
1868  int64_t rank = inputTy.getRank();
1869 
1870  ArrayRef<int64_t> multiples = op.getMultiples();
1871 
1872  // Broadcast the newly added dimensions to their appropriate multiple.
1873  SmallVector<int64_t, 2> genericShape;
1874  for (int i = 0; i < rank; i++) {
1875  int64_t dim = multiples[i];
1876  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1877  genericShape.push_back(inputShape[i]);
1878  }
1879 
1880  SmallVector<Value> dynDims;
1881  for (int i = 0; i < inputTy.getRank(); i++) {
1882  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1883  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1884  }
1885  }
1886 
1887  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1888  op.getLoc(), genericShape, elementTy, dynDims);
1889 
1890  // We needs to map the input shape to the non-broadcasted dimensions.
1891  SmallVector<AffineExpr, 4> dimExprs;
1892  dimExprs.reserve(rank);
1893  for (unsigned i = 0; i < rank; ++i)
1894  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1895 
1896  auto readAffineMap =
1897  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1898  rewriter.getContext());
1899 
1900  SmallVector<AffineMap, 2> affineMaps = {
1901  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1902 
1903  auto genericOp = rewriter.create<linalg::GenericOp>(
1904  loc, RankedTensorType::get(genericShape, elementTy), input,
1905  ValueRange{emptyTensor}, affineMaps,
1906  getNParallelLoopsAttrs(genericShape.size()),
1907  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1908  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1909  });
1910 
1911  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1912  op, resultTy, genericOp.getResult(0),
1913  rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
1914  return success();
1915  }
1916 };
1917 
1918 // Tosa argmax lowering represents the ArgMax op as an linalg.indexed_generic
1919 // op, producing two output buffers.
1920 //
1921 // The first output buffer contains the index of the found maximum value. It is
1922 // initialized to 0 and is resulting integer type.
1923 //
1924 // The second output buffer contains the maximum value found. It is initialized
1925 // to the minimum representable value of the input element type. After being
1926 // populated by indexed_generic, this buffer is disgarded as only the index is
1927 // requested.
1928 //
1929 // The indexed_generic op updates both the maximum value and index if the
1930 // current value exceeds the running max.
1931 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1932 public:
1934 
1935  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1936  PatternRewriter &rewriter) const final {
1937  auto loc = argmaxOp.getLoc();
1938  Value input = argmaxOp.getInput();
1939  auto inputTy = cast<ShapedType>(input.getType());
1940  auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
1941  auto inElementTy = inputTy.getElementType();
1942  auto outElementTy = resultTy.getElementType();
1943  int axis = argmaxOp.getAxis();
1944  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1945 
1946  if (!isa<IntegerType>(outElementTy))
1947  return rewriter.notifyMatchFailure(
1948  argmaxOp,
1949  "tosa.arg_max to linalg.* requires integer-like result type");
1950 
1951  SmallVector<Value> dynDims;
1952  for (int i = 0; i < inputTy.getRank(); i++) {
1953  if (inputTy.isDynamicDim(i) && i != axis) {
1954  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1955  }
1956  }
1957 
1958  // First fill the output buffer for the index.
1959  auto emptyTensorIdx = rewriter
1960  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1961  outElementTy, dynDims)
1962  .getResult();
1963  auto fillValueIdx = rewriter.create<arith::ConstantOp>(
1964  loc, rewriter.getIntegerAttr(outElementTy, 0));
1965  auto filledTensorIdx =
1966  rewriter
1967  .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
1968  ValueRange{emptyTensorIdx})
1969  .result();
1970 
1971  // Second fill the output buffer for the running max.
1972  auto emptyTensorMax = rewriter
1973  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1974  inElementTy, dynDims)
1975  .getResult();
1976  auto fillValueMaxAttr =
1977  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
1978 
1979  if (!fillValueMaxAttr)
1980  return rewriter.notifyMatchFailure(
1981  argmaxOp, "unsupported tosa.argmax element type");
1982 
1983  auto fillValueMax =
1984  rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
1985  auto filledTensorMax =
1986  rewriter
1987  .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
1988  ValueRange{emptyTensorMax})
1989  .result();
1990 
1991  // We need to reduce along the arg-max axis, with parallel operations along
1992  // the rest.
1994  iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
1995  iteratorTypes[axis] = utils::IteratorType::reduction;
1996 
1997  SmallVector<AffineExpr, 2> srcExprs;
1998  SmallVector<AffineExpr, 2> dstExprs;
1999  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2000  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2001  if (axis != i)
2002  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2003  }
2004 
2005  bool didEncounterError = false;
2006  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2007  rewriter.getContext());
2008  auto linalgOp = rewriter.create<linalg::GenericOp>(
2009  loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2010  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2011  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2012  ValueRange blockArgs) {
2013  auto newValue = blockArgs[0];
2014  auto oldIndex = blockArgs[1];
2015  auto oldValue = blockArgs[2];
2016 
2017  Value newIndex = rewriter.create<arith::IndexCastOp>(
2018  nestedLoc, oldIndex.getType(),
2019  rewriter.create<linalg::IndexOp>(loc, axis));
2020 
2021  Value predicate;
2022  if (isa<FloatType>(inElementTy)) {
2023  predicate = rewriter.create<arith::CmpFOp>(
2024  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2025  } else if (isa<IntegerType>(inElementTy)) {
2026  predicate = rewriter.create<arith::CmpIOp>(
2027  nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2028  } else {
2029  didEncounterError = true;
2030  return;
2031  }
2032 
2033  auto resultMax = rewriter.create<arith::SelectOp>(
2034  nestedLoc, predicate, newValue, oldValue);
2035  auto resultIndex = rewriter.create<arith::SelectOp>(
2036  nestedLoc, predicate, newIndex, oldIndex);
2037  nestedBuilder.create<linalg::YieldOp>(
2038  nestedLoc, ValueRange({resultIndex, resultMax}));
2039  });
2040 
2041  if (didEncounterError)
2042  return rewriter.notifyMatchFailure(
2043  argmaxOp, "unsupported tosa.argmax element type");
2044 
2045  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2046  return success();
2047  }
2048 };
2049 
2050 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2051 public:
2054  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2055  ConversionPatternRewriter &rewriter) const final {
2056  auto input = adaptor.getOperands()[0];
2057  auto indices = adaptor.getOperands()[1];
2058 
2059  auto valuesTy =
2060  dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2061  auto resultTy = cast<ShapedType>(op.getType());
2062 
2063  if (!valuesTy)
2064  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2065 
2066  auto dynamicDims = inferDynamicDimsForGather(
2067  rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2068 
2069  auto resultElementTy = resultTy.getElementType();
2070 
2071  auto loc = op.getLoc();
2072  auto emptyTensor =
2073  rewriter
2074  .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2075  dynamicDims)
2076  .getResult();
2077 
2078  SmallVector<AffineMap, 2> affineMaps = {
2080  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2081  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2082  rewriter.getContext()),
2083  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2084 
2085  auto genericOp = rewriter.create<linalg::GenericOp>(
2086  loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2087  ValueRange{emptyTensor}, affineMaps,
2088  getNParallelLoopsAttrs(resultTy.getRank()),
2089  [&](OpBuilder &b, Location loc, ValueRange args) {
2090  auto indexValue = args[0];
2091  auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2092  Value index1 = rewriter.create<arith::IndexCastOp>(
2093  loc, rewriter.getIndexType(), indexValue);
2094  auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2095  Value extract = rewriter.create<tensor::ExtractOp>(
2096  loc, input, ValueRange{index0, index1, index2});
2097  rewriter.create<linalg::YieldOp>(loc, extract);
2098  });
2099  rewriter.replaceOp(op, genericOp.getResult(0));
2100  return success();
2101  }
2102 
2103  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2104  Location loc,
2105  Value values,
2106  Value indices) {
2107  llvm::SmallVector<Value> results;
2108 
2109  auto addDynamicDimension = [&](Value source, int64_t dim) {
2110  auto sz = tensor::getMixedSize(builder, loc, source, dim);
2111  if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2112  results.push_back(dimValue);
2113  };
2114 
2115  addDynamicDimension(values, 0);
2116  addDynamicDimension(indices, 1);
2117  addDynamicDimension(values, 2);
2118  return results;
2119  }
2120 };
2121 
2122 // Lowerings the TableOp to a series of gathers and numerica operations. This
2123 // includes interpolation between the high/low values. For the I8 varient, this
2124 // simplifies to a single gather operation.
2125 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2126 public:
2128 
2129  LogicalResult matchAndRewrite(tosa::TableOp op,
2130  PatternRewriter &rewriter) const final {
2131  auto loc = op.getLoc();
2132  Value input = op.getInput();
2133  Value table = op.getTable();
2134  auto inputTy = cast<ShapedType>(input.getType());
2135  auto tableTy = cast<ShapedType>(table.getType());
2136  auto resultTy = cast<ShapedType>(op.getType());
2137 
2138  auto inputElementTy = inputTy.getElementType();
2139  auto tableElementTy = tableTy.getElementType();
2140  auto resultElementTy = resultTy.getElementType();
2141 
2142  SmallVector<Value> dynDims;
2143  for (int i = 0; i < resultTy.getRank(); ++i) {
2144  if (inputTy.isDynamicDim(i)) {
2145  dynDims.push_back(
2146  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2147  }
2148  }
2149 
2150  auto emptyTensor = rewriter
2151  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2152  resultElementTy, dynDims)
2153  .getResult();
2154 
2155  SmallVector<AffineMap, 2> affineMaps = {
2156  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2157  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2158 
2159  auto genericOp = rewriter.create<linalg::GenericOp>(
2160  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2161  getNParallelLoopsAttrs(resultTy.getRank()));
2162  rewriter.replaceOp(op, genericOp.getResult(0));
2163 
2164  {
2165  OpBuilder::InsertionGuard regionGuard(rewriter);
2166  Block *block = rewriter.createBlock(
2167  &genericOp.getRegion(), genericOp.getRegion().end(),
2168  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2169 
2170  auto inputValue = block->getArgument(0);
2171  rewriter.setInsertionPointToStart(block);
2172  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2173  resultElementTy.isInteger(8)) {
2174  Value index = rewriter.create<arith::IndexCastOp>(
2175  loc, rewriter.getIndexType(), inputValue);
2176  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2177  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2178  index, offset);
2179  Value extract =
2180  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2181  rewriter.create<linalg::YieldOp>(loc, extract);
2182  return success();
2183  }
2184 
2185  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2186  resultElementTy.isInteger(32)) {
2187  Value extend = rewriter.create<arith::ExtSIOp>(
2188  loc, rewriter.getI32Type(), inputValue);
2189 
2190  auto offset = rewriter.create<arith::ConstantOp>(
2191  loc, rewriter.getI32IntegerAttr(32768));
2192  auto seven = rewriter.create<arith::ConstantOp>(
2193  loc, rewriter.getI32IntegerAttr(7));
2194  auto one = rewriter.create<arith::ConstantOp>(
2195  loc, rewriter.getI32IntegerAttr(1));
2196  auto b1111111 = rewriter.create<arith::ConstantOp>(
2197  loc, rewriter.getI32IntegerAttr(127));
2198 
2199  // Compute the index and fractional part from the input value:
2200  // value = value + 32768
2201  // index = value >> 7;
2202  // fraction = 0x01111111 & value
2203  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2204  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2205  Value fraction =
2206  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2207 
2208  // Extract the base and next values from the table.
2209  // base = (int32_t) table[index];
2210  // next = (int32_t) table[index + 1];
2211  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2212 
2213  index = rewriter.create<arith::IndexCastOp>(
2214  loc, rewriter.getIndexType(), index);
2215  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2216  loc, rewriter.getIndexType(), indexPlusOne);
2217 
2218  Value base =
2219  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2220  Value next = rewriter.create<tensor::ExtractOp>(
2221  loc, table, ValueRange{indexPlusOne});
2222 
2223  base =
2224  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2225  next =
2226  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2227 
2228  // Use the fractional part to interpolate between the input values:
2229  // result = (base << 7) + (next - base) * fraction
2230  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2231  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2232  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2233  Value result =
2234  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2235 
2236  rewriter.create<linalg::YieldOp>(loc, result);
2237 
2238  return success();
2239  }
2240  }
2241 
2242  return rewriter.notifyMatchFailure(
2243  op, "unable to create body for tosa.table op");
2244  }
2245 };
2246 
2247 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2249 
2250  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2251 
2252  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2253  OpFoldResult ofr) {
2254  auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
2255  auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
2256 
2257  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2258  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2259  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2260  return getAsOpFoldResult(plusOne);
2261  }
2262 
2263  static RankedTensorType
2264  computeOutputShape(OpBuilder &builder, Location loc, Value input,
2265  llvm::SmallVectorImpl<Value> &dynamicSizes) {
2266  // Get [N, H, W]
2267  auto dims = tensor::getMixedSizes(builder, loc, input);
2268 
2269  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2270  // output tensors.
2271  dims[2] = halfPlusOne(builder, loc, dims[2]);
2272 
2273  llvm::SmallVector<int64_t, 3> staticSizes;
2274  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2275 
2276  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2277  return RankedTensorType::get(staticSizes, elementType);
2278  }
2279 
2280  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2281  RankedTensorType type,
2282  llvm::ArrayRef<Value> dynamicSizes) {
2283  auto emptyTensor =
2284  rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2285  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2286  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2287  auto filledTensor = rewriter
2288  .create<linalg::FillOp>(loc, ValueRange{fillValue},
2289  ValueRange{emptyTensor})
2290  .result();
2291  return filledTensor;
2292  }
2293 
2294  static Value castIndexToFloat(OpBuilder &builder, Location loc,
2295  FloatType type, Value value) {
2296  auto integerVal = builder.create<arith::IndexCastUIOp>(
2297  loc,
2298  type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2299  : builder.getI32Type(),
2300  value);
2301 
2302  return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2303  }
2304 
2305  static Value createLinalgIndex(OpBuilder &builder, Location loc,
2306  FloatType type, int64_t index) {
2307  auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2308  return castIndexToFloat(builder, loc, type, indexVal);
2309  }
2310 
2311  template <typename... Args>
2312  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2313  Args... args) {
2314  return {builder.getAffineDimExpr(args)...};
2315  }
2316 
2317  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2318  PatternRewriter &rewriter) const override {
2319  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2320  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2321  return rewriter.notifyMatchFailure(rfft2d,
2322  "only supports ranked tensors");
2323  }
2324 
2325  auto loc = rfft2d.getLoc();
2326  auto input = rfft2d.getInput();
2327  auto elementType =
2328  cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2329 
2330  // Compute the output type and set of dynamic sizes
2331  llvm::SmallVector<Value> dynamicSizes;
2332  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2333 
2334  // Iterator types for the linalg.generic implementation
2336  utils::IteratorType::parallel, utils::IteratorType::parallel,
2337  utils::IteratorType::parallel, utils::IteratorType::reduction,
2338  utils::IteratorType::reduction};
2339 
2340  // Inputs/outputs to the linalg.generic implementation
2341  llvm::SmallVector<Value> genericOpInputs = {input};
2342  llvm::SmallVector<Value> genericOpOutputs = {
2343  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2344  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2345 
2346  // Indexing maps for input and output tensors
2347  auto indexingMaps = AffineMap::inferFromExprList(
2348  llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2349  affineDimsExpr(rewriter, 0, 1, 2),
2350  affineDimsExpr(rewriter, 0, 1, 2)},
2351  rewriter.getContext());
2352 
2353  // Width and height dimensions of the original input.
2354  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2355  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2356 
2357  // Constants and dimension sizes
2358  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2359  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2360  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2361  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2362 
2363  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2364  Value valReal = args[0];
2365  Value sumReal = args[1];
2366  Value sumImag = args[2];
2367 
2368  // Indices for angle computation
2369  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2370  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2371  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2372  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2373 
2374  // Calculating angle without integer parts of components as sin/cos are
2375  // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2376  // / W);
2377  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2378  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2379 
2380  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2381  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2382 
2383  auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2384  auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2385 
2386  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2387  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2388  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2389  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2390 
2391  // realComponent = valReal * cos(angle)
2392  // imagComponent = valReal * sin(angle)
2393  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2394  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2395  auto realComponent =
2396  builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2397  auto imagComponent =
2398  builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2399 
2400  // outReal = sumReal + realComponent
2401  // outImag = sumImag - imagComponent
2402  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2403  auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2404 
2405  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2406  };
2407 
2408  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2409  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2410  indexingMaps, iteratorTypes, buildBody);
2411 
2412  return success();
2413  }
2414 };
2415 
2416 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2418 
2419  LogicalResult matchAndRewrite(FFT2dOp fft2d,
2420  PatternRewriter &rewriter) const override {
2421  if (!llvm::all_of(fft2d->getOperandTypes(),
2422  RFFT2dConverter::isRankedTensor) ||
2423  !llvm::all_of(fft2d->getResultTypes(),
2424  RFFT2dConverter::isRankedTensor)) {
2425  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2426  }
2427 
2428  Location loc = fft2d.getLoc();
2429  Value input_real = fft2d.getInputReal();
2430  Value input_imag = fft2d.getInputImag();
2431  BoolAttr inverse = fft2d.getInverseAttr();
2432 
2433  auto real_el_ty = cast<FloatType>(
2434  cast<ShapedType>(input_real.getType()).getElementType());
2435  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2436  cast<ShapedType>(input_imag.getType()).getElementType());
2437 
2438  assert(real_el_ty == imag_el_ty);
2439 
2440  // Compute the output type and set of dynamic sizes
2441  SmallVector<Value> dynamicSizes;
2442 
2443  // Get [N, H, W]
2444  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2445 
2446  SmallVector<int64_t, 3> staticSizes;
2447  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2448 
2449  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2450 
2451  // Iterator types for the linalg.generic implementation
2452  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2453  utils::IteratorType::parallel, utils::IteratorType::parallel,
2454  utils::IteratorType::parallel, utils::IteratorType::reduction,
2455  utils::IteratorType::reduction};
2456 
2457  // Inputs/outputs to the linalg.generic implementation
2458  SmallVector<Value> genericOpInputs = {input_real, input_imag};
2459  SmallVector<Value> genericOpOutputs = {
2460  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2461  dynamicSizes),
2462  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2463  dynamicSizes)};
2464 
2465  // Indexing maps for input and output tensors
2466  auto indexingMaps = AffineMap::inferFromExprList(
2467  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2468  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2469  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2470  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2471  rewriter.getContext());
2472 
2473  // Width and height dimensions of the original input.
2474  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2475  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2476 
2477  // Constants and dimension sizes
2478  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2479  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2480  Value constH =
2481  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2482  Value constW =
2483  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2484 
2485  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2486  Value valReal = args[0];
2487  Value valImag = args[1];
2488  Value sumReal = args[2];
2489  Value sumImag = args[3];
2490 
2491  // Indices for angle computation
2492  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2493  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2494  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2495  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2496 
2497  // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2498  // ox) % W ) / W);
2499  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2500  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2501 
2502  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2503  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2504 
2505  auto iyRemFloat =
2506  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2507  auto ixRemFloat =
2508  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2509 
2510  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2511  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2512 
2513  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2514  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2515 
2516  if (inverse.getValue()) {
2517  angle = builder.create<arith::MulFOp>(
2518  loc, angle,
2519  rewriter.create<arith::ConstantOp>(
2520  loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2521  }
2522 
2523  // realComponent = val_real * cos(a) + val_imag * sin(a);
2524  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2525  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2526  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2527 
2528  auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2529  auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2530  auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2531 
2532  auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2533  auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2534 
2535  auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2536 
2537  // outReal = sumReal + realComponent
2538  // outImag = sumImag - imagComponent
2539  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2540  auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2541 
2542  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2543  };
2544 
2545  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2546  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2547  indexingMaps, iteratorTypes, buildBody);
2548 
2549  return success();
2550  }
2551 };
2552 
2553 } // namespace
2554 
    RewritePatternSet *patterns) {
  // Registers the TOSA -> Linalg rewrite patterns listed below into
  // `patterns`, all constructed with the pattern set's context.

  // We have multiple resize converters to handle degenerate cases. They are
  // registered with distinct benefits so the more specialised patterns are
  // preferred: MaterializeResizeBroadcast (300), then ResizeUnaryConverter
  // (200), with GenericResizeConverter (100) as the general fallback.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);

  // The remaining patterns are added without an explicit benefit, so each
  // uses its constructor's default.
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::DivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor, int64_t rank)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, PatternRewriter &rewriter)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, PatternRewriter &rewriter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter)
static LogicalResult emitElementwiseComputation(PatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static SmallVector< Value > expandInputRanks(PatternRewriter &rewriter, Location loc, Operation *operation)
static bool operandsAndResultsRanked(Operation *operation)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:296
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:379
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:91
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
MPInt round(const Fraction &f)
Definition: Fraction.h:133
Fraction abs(const Fraction &f)
Definition: Fraction.h:104
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:51
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362