1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 
34 #include <numeric>
35 
36 using namespace mlir;
37 using namespace mlir::tosa;
38 
39 template <typename T>
40 static arith::ConstantOp
41 createConstFromIntAttribute(Operation *op, const std::string &attrName,
42  Type requiredAttrType, OpBuilder &rewriter) {
43  auto castedN = static_cast<T>(
44  cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
45  return rewriter.create<arith::ConstantOp>(
46  op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
47 }
48 
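// Builds the scalar computation that forms the body of the 'linalg.generic'
// implementing a given elementwise TOSA op. As an illustrative sketch (not
// emitted verbatim), lowering 'tosa.add' on f32 tensors yields a generic op
// whose region looks roughly like:
//
//   ^bb0(%lhs: f32, %rhs: f32, %out: f32):
//     %0 = arith.addf %lhs, %rhs : f32
//     linalg.yield %0 : f32
//
// This helper only creates the scalar op (%0 above); the enclosing
// 'linalg.generic' is built by the elementwise lowering further below.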
49 static Value createLinalgBodyCalculationForElementwiseOp(
50  Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51  ConversionPatternRewriter &rewriter) {
52  Location loc = op->getLoc();
53  auto elementTy =
54  cast<ShapedType>(op->getOperand(0).getType()).getElementType();
55 
56  // tosa::AbsOp
57  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
58  return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
59 
60  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
61  auto zero = rewriter.create<arith::ConstantOp>(
62  loc, rewriter.getZeroAttr(elementTy));
63  auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
64  return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
65  }
66 
67  // tosa::AddOp
68  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
69  return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
70 
71  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
72  return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
73 
74  // tosa::SubOp
75  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
76  return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
77 
78  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
79  return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
80 
81  // tosa::IntDivOp
82  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
83  return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
84 
85  // tosa::ReciprocalOp
86  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
87  auto one =
88  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
89  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
90  }
91 
92  // tosa::MulOp
93  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
94  return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
95 
96  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
97  Value a = args[0];
98  Value b = args[1];
99  auto shift =
100  cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
101  if (shift > 0) {
102  auto shiftConst =
103  rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
104  if (!a.getType().isInteger(32))
105  a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
106 
107  if (!b.getType().isInteger(32))
108  b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
109 
110  auto result = rewriter.create<tosa::ApplyScaleOp>(
111  loc, rewriter.getI32Type(), a, b, shiftConst,
112  rewriter.getBoolAttr(false));
113 
114  if (elementTy.isInteger(32))
115  return result;
116 
117  return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
118  }
119 
120  int aWidth = a.getType().getIntOrFloatBitWidth();
121  int bWidth = b.getType().getIntOrFloatBitWidth();
122  int cWidth = resultTypes[0].getIntOrFloatBitWidth();
123 
124  if (aWidth < cWidth)
125  a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
126  if (bWidth < cWidth)
127  b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
128 
129  return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
130  }
131 
132  // tosa::NegateOp
133  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
134  return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
135 
136  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
137  int64_t inZp = 0, outZp = 0;
138 
139  if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
140  auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
141  inZp = quantizationInfo.value().getInputZp();
142  outZp = quantizationInfo.value().getOutputZp();
143  }
144 
145  int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
146  if (!inZp && !outZp) {
147  auto constant = rewriter.create<arith::ConstantOp>(
148  loc, IntegerAttr::get(elementTy, 0));
149  return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
150  args[0]);
151  }
152 
153  // Compute the maximum value that can occur in the intermediate buffer.
154  int64_t zpAdd = inZp + outZp;
155  int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
156  std::abs(zpAdd) + 1;
157 
158  // Convert that maximum value into the maximum bitwidth needed to represent
159  // it. We assume 48-bit numbers may be supported further in the pipeline.
160  int intermediateBitWidth = 64;
161  if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
162  intermediateBitWidth = 16;
163  } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
164  intermediateBitWidth = 32;
165  } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
166  intermediateBitWidth = 48;
167  }
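    // Worked example (illustrative): for an i8 input with inZp = -128 and
    // outZp = 0, zpAdd = -128 and maxValue = 127 + 128 + 1 = 256, which does
    // not fit into 8 bits, so a 16-bit intermediate type is chosen above.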
168 
169  Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
170  Value zpAddValue = rewriter.create<arith::ConstantOp>(
171  loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
172 
173  // The negation can be applied by doing:
174  // outputValue = inZp + outZp - inputValue
175  auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
176  auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
177 
178  // Clamp to the negation range.
179  Value min = rewriter.create<arith::ConstantIntOp>(
180  loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
181  intermediateType);
182  Value max = rewriter.create<arith::ConstantIntOp>(
183  loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
184  intermediateType);
185  auto clamp =
186  clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
187 
188  // Truncate to the final value.
189  return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
190  }
191 
192  // tosa::BitwiseAndOp
193  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
194  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
195 
196  // tosa::BitwiseOrOp
197  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
198  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
199 
200  // tosa::BitwiseNotOp
201  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
202  auto allOnesAttr = rewriter.getIntegerAttr(
203  elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
204  auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
205  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
206  }
207 
208  // tosa::BitwiseXOrOp
209  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
210  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
211 
212  // tosa::LogicalLeftShiftOp
213  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
214  return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
215 
216  // tosa::LogicalRightShiftOp
217  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
218  return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
219 
220  // tosa::ArithmeticRightShiftOp
221  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
222  auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
223  auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
224  if (!round) {
225  return result;
226  }
227 
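    // Rounding example (illustrative): shifting input1 = 5 (0b101) right by
    // input2 = 1 gives 2; the last bit shifted out is 1 and the shift amount
    // is positive, so 1 is added back below to produce the rounded result 3.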
228  Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
229  auto one =
230  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
231  auto zero =
232  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
233  auto i1one =
234  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
235 
236  // Check that the shift amount (input2) is greater than zero.
237  auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
238  loc, arith::CmpIPredicate::sgt, args[1], zero);
239 
240  // Check whether the last bit shifted out of input1 is 1.
241  auto subtract =
242  rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
243  auto shifted =
244  rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
245  ->getResults();
246  auto truncated =
247  rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
248  auto isInputOdd =
249  rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
250 
251  auto shouldRound = rewriter.create<arith::AndIOp>(
252  loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
253  auto extended =
254  rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
255  return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
256  }
257 
258  // tosa::ClzOp
259  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
260  return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
261  }
262 
263  // tosa::LogicalAnd
264  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
265  return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
266 
267  // tosa::LogicalNot
268  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
269  auto one = rewriter.create<arith::ConstantOp>(
270  loc, rewriter.getIntegerAttr(elementTy, 1));
271  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
272  }
273 
274  // tosa::LogicalOr
275  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
276  return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
277 
278  // tosa::LogicalXor
279  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
280  return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
281 
282  // tosa::PowOp
283  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
284  return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
285 
286  // tosa::RsqrtOp
287  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
288  return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
289 
290  // tosa::LogOp
291  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
292  return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
293 
294  // tosa::ExpOp
295  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
296  return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
297 
298  // tosa::SinOp
299  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
300  return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
301 
302  // tosa::CosOp
303  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
304  return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
305 
306  // tosa::TanhOp
307  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
308  return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
309 
310  // tosa::ErfOp
311  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
312  return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
313 
314  // tosa::GreaterOp
315  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
316  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
317  args[0], args[1]);
318 
319  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
320  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
321  args[0], args[1]);
322 
323  // tosa::GreaterEqualOp
324  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
325  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
326  args[0], args[1]);
327 
328  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
329  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
330  args[0], args[1]);
331 
332  // tosa::EqualOp
333  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
334  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
335  args[0], args[1]);
336 
337  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
338  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
339  args[0], args[1]);
340 
341  // tosa::SelectOp
342  if (isa<tosa::SelectOp>(op)) {
343  elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
344  if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
345  return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
346  }
347 
348  // tosa::MaximumOp
349  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
350  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
351  }
352 
353  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
354  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
355  }
356 
357  // tosa::MinimumOp
358  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
359  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
360  }
361 
362  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
363  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
364  }
365 
366  // tosa::CeilOp
367  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
368  return rewriter.create<math::CeilOp>(loc, resultTypes, args);
369 
370  // tosa::FloorOp
371  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
372  return rewriter.create<math::FloorOp>(loc, resultTypes, args);
373 
374  // tosa::ClampOp
375  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
376  bool losesInfo = false;
377  APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
378  APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
379  minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
380  APFloat::rmNearestTiesToEven, &losesInfo);
381  maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
382  APFloat::rmNearestTiesToEven, &losesInfo);
383  auto min = rewriter.create<arith::ConstantOp>(
384  loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
385  auto max = rewriter.create<arith::ConstantOp>(
386  loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
387  return clampFloatHelper(loc, args[0], min, max, rewriter);
388  }
389 
390  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
391  auto intTy = cast<IntegerType>(elementTy);
392  int64_t min =
393  cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
394  int64_t max =
395  cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
396 
397  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
398  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
399  if (intTy.isUnsignedInteger()) {
400  minRepresentable = 0;
401  if (intTy.getIntOrFloatBitWidth() <= 63) {
402  maxRepresentable =
403  (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
404  .getZExtValue();
405  }
406  } else if (intTy.getIntOrFloatBitWidth() <= 64) {
407  // Ensure that min & max fit into signed n-bit constants.
408  minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
409  .getSExtValue();
410  maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
411  .getSExtValue();
412  }
413  // Ensure that the bounds are representable as n-bit signed/unsigned
414  // integers.
415  min = std::max(min, minRepresentable);
416  max = std::max(max, minRepresentable);
417  min = std::min(min, maxRepresentable);
418  max = std::min(max, maxRepresentable);
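    // Worked example (illustrative): for an i8 clamp with min_int = -500 and
    // max_int = 500, the bounds are narrowed to [-128, 127] so that the
    // constants created below fit into 8 bits.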
419 
420  auto minVal = rewriter.create<arith::ConstantIntOp>(
421  loc, min, intTy.getIntOrFloatBitWidth());
422  auto maxVal = rewriter.create<arith::ConstantIntOp>(
423  loc, max, intTy.getIntOrFloatBitWidth());
424  return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
425  intTy.isUnsignedInteger());
426  }
427 
428  // tosa::SigmoidOp
429  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
430  auto one =
431  rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
432  auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
433  auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
434  auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
435  return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
436  }
437 
438  // tosa::CastOp
439  if (isa<tosa::CastOp>(op)) {
440  Type srcTy = elementTy;
441  Type dstTy = resultTypes.front();
442  bool bitExtend =
443  srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
444 
445  if (srcTy == dstTy)
446  return args.front();
447 
448  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
449  return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
450  std::nullopt);
451 
452  if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
453  return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
454  std::nullopt);
455 
456  // 1-bit integers need to be treated as signless.
457  if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
458  return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
459  std::nullopt);
460 
461  if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
462  return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
463  std::nullopt);
464 
465  // Unsigned integers need an unrealized cast so that they can be passed
466  // to UIToFP.
467  if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
468  auto unrealizedCast =
469  rewriter
470  .create<UnrealizedConversionCastOp>(
471  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
472  args[0])
473  .getResult(0);
474  return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
475  unrealizedCast);
476  }
477 
478  // All other si-to-fp conversions should be handled by SIToFP.
479  if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
480  return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
481  std::nullopt);
482 
483  // When casting to boolean, floats only need to be compared against zero.
484  if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
485  Value zero = rewriter.create<arith::ConstantOp>(
486  loc, rewriter.getFloatAttr(srcTy, 0.0));
487  return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
488  args.front(), zero);
489  }
490 
491  if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
492  auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
493 
494  const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
495  // Check whether neither int min nor int max can be represented in the
496  // input floating-point type due to too short exponent range.
497  if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
498  APFloat::semanticsMaxExponent(fltSemantics)) {
499  // Use cmp + select to replace infinites by int min / int max. Other
500  // integral values can be represented in the integer space.
501  auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
502  auto posInf = rewriter.create<arith::ConstantOp>(
503  loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
504  APFloat::getInf(fltSemantics)));
505  auto negInf = rewriter.create<arith::ConstantOp>(
506  loc, rewriter.getFloatAttr(
507  getElementTypeOrSelf(srcTy),
508  APFloat::getInf(fltSemantics, /*Negative=*/true)));
509  auto overflow = rewriter.create<arith::CmpFOp>(
510  loc, arith::CmpFPredicate::UEQ, rounded, posInf);
511  auto underflow = rewriter.create<arith::CmpFOp>(
512  loc, arith::CmpFPredicate::UEQ, rounded, negInf);
513  auto intMin = rewriter.create<arith::ConstantOp>(
514  loc, rewriter.getIntegerAttr(
515  getElementTypeOrSelf(dstTy),
516  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
517  auto intMax = rewriter.create<arith::ConstantOp>(
518  loc, rewriter.getIntegerAttr(
519  getElementTypeOrSelf(dstTy),
520  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
521  auto maxClamped =
522  rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
523  return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
524  maxClamped);
525  }
526 
527  auto intMinFP = rewriter.create<arith::ConstantOp>(
528  loc, rewriter.getFloatAttr(
529  getElementTypeOrSelf(srcTy),
530  APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
531  .getSExtValue()));
532 
533  // Check whether the mantissa has enough bits to represent int max.
534  if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
535  dstTy.getIntOrFloatBitWidth() - 1) {
536  // Int min can also be represented since it is a power of two and thus
537  // consists of a single leading bit. Therefore we can clamp the input
538  // in the floating-point domain.
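      // Worked example (illustrative): casting f32 (24-bit mantissa) to i16
      // satisfies 24 >= 16 - 1, so both -32768 and 32767 are exactly
      // representable in f32 and the clamp can happen before the fptosi.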
539 
540  auto intMaxFP = rewriter.create<arith::ConstantOp>(
541  loc, rewriter.getFloatAttr(
542  getElementTypeOrSelf(srcTy),
543  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
544  .getSExtValue()));
545 
546  Value clamped =
547  clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
548  return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
549  }
550 
551  // Due to the earlier check we know the exponent range is big enough to
552  // int min. We can therefore rely on int max + 1 being representable as
553  // well because it's just int min with a positive sign. So clamp the min
554  // value and compare against that to select the max int value if needed.
555  auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
556  loc, rewriter.getFloatAttr(
557  getElementTypeOrSelf(srcTy),
558  static_cast<double>(
559  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
560  .getSExtValue()) +
561  1.0f));
562 
563  auto intMax = rewriter.create<arith::ConstantOp>(
564  loc, rewriter.getIntegerAttr(
565  getElementTypeOrSelf(dstTy),
566  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
567  auto minClampedFP =
568  rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
569  auto minClamped =
570  rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
571  auto overflow = rewriter.create<arith::CmpFOp>(
572  loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
573  return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
574  minClamped);
575  }
576 
577  // When casting to boolean, integers only need to be compared against
578  // zero.
579  if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
580  Value zero = rewriter.create<arith::ConstantIntOp>(
581  loc, 0, srcTy.getIntOrFloatBitWidth());
582  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
583  args.front(), zero);
584  }
585 
586  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
587  return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
588  std::nullopt);
589 
590  if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
591  return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
592  }
593  }
594 
595  (void)rewriter.notifyMatchFailure(
596  op, "unhandled op for linalg body calculation for elementwise op");
597  return nullptr;
598 }
599 
600 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
601  int64_t rank) {
602  // No need to expand if we are already at the desired rank
603  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
604  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
605  int64_t numExtraDims = rank - shapedType.getRank();
606  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
607  if (!numExtraDims)
608  return tensor;
609 
610  // Compute reassociation indices
611  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
612  shapedType.getRank());
613  int64_t index = 0;
614  for (index = 0; index <= numExtraDims; index++)
615  reassociationIndices[0].push_back(index);
616  for (size_t position = 1; position < reassociationIndices.size(); position++)
617  reassociationIndices[position].push_back(index++);
618 
619  // Compute result type
620  SmallVector<int64_t> resultShape;
621  for (index = 0; index < numExtraDims; index++)
622  resultShape.push_back(1);
623  for (auto size : shapedType.getShape())
624  resultShape.push_back(size);
625  auto resultType =
626  RankedTensorType::get(resultShape, shapedType.getElementType());
627 
628  // Emit 'tensor.expand_shape' op
629  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
630  reassociationIndices);
631 }
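// Illustrative example: expanding a tensor<3x4xf32> to rank 4 emits a
// 'tensor.expand_shape' with reassociation [[0, 1, 2], [3]] and result type
// tensor<1x1x3x4xf32>; the unit dimensions are prepended so that all operands
// of an elementwise op can later be treated uniformly.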
632 
633 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
634  Location loc, ValueRange operands,
635  int64_t rank) {
636  return llvm::map_to_vector(operands, [&](Value operand) {
637  return expandRank(rewriter, loc, operand, rank);
638  });
639 }
640 
641 using IndexPool = DenseMap<int64_t, Value>;
642 
643 // Emit an 'arith.constant' op for the given index if it has not been created
644 // yet, or return the existing constant. This prevents excessive creation of
645 // redundant constants and eases readability of the emitted code in unit tests.
646 static Value createIndex(PatternRewriter &rewriter, Location loc,
647  IndexPool &indexPool, int64_t index) {
648  auto [it, inserted] = indexPool.try_emplace(index);
649  if (inserted)
650  it->second =
651  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
652  return it->second;
653 }
654 
655 static Value getTensorDim(PatternRewriter &rewriter, Location loc,
656  IndexPool &indexPool, Value tensor, int64_t index) {
657  auto indexValue = createIndex(rewriter, loc, indexPool, index);
658  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
659 }
660 
661 static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
662  IndexPool &indexPool, Value tensor,
663  int64_t index) {
664  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
665  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
666  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
667  if (shapedType.isDynamicDim(index))
668  return getTensorDim(rewriter, loc, indexPool, tensor, index);
669  return rewriter.getIndexAttr(shapedType.getDimSize(index));
670 }
671 
672 static bool operandsAndResultsRanked(Operation *operation) {
673  auto isRanked = [](Value value) {
674  return isa<RankedTensorType>(value.getType());
675  };
676  return llvm::all_of(operation->getOperands(), isRanked) &&
677  llvm::all_of(operation->getResults(), isRanked);
678 }
679 
680 // Compute the runtime dimension size for dimension 'dim' of the output by
681 // inspecting input 'operands', all of which are expected to have the same rank.
682 // This function returns a pair {targetSize, masterOperand}.
683 //
684 // The runtime size of the output dimension is returned either as a statically
685 // computed attribute or as a runtime SSA value.
686 //
687 // If the target size was inferred directly from one dominating operand, that
688 // operand is returned in 'masterOperand'. If the target size is inferred from
689 // multiple operands, 'masterOperand' is set to nullptr.
690 static std::pair<OpFoldResult, Value>
691 computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
692  ValueRange operands, int64_t dim) {
693  // If any input operand contains a static size greater than 1 for this
694  // dimension, that is the target size. An occurrence of an additional static
695  // dimension greater than 1 with a different value is undefined behavior.
696  for (auto operand : operands) {
697  auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
698  if (!ShapedType::isDynamic(size) && size > 1)
699  return {rewriter.getIndexAttr(size), operand};
700  }
701 
702  // Filter operands with dynamic dimension
703  auto operandsWithDynamicDim =
704  llvm::filter_to_vector(operands, [&](Value operand) {
705  return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
706  });
707 
708  // If no operand has a dynamic dimension, it means all sizes were 1
709  if (operandsWithDynamicDim.empty())
710  return {rewriter.getIndexAttr(1), operands.front()};
711 
712  // Emit code that computes the runtime size for this dimension. If there is
713  // only one operand with a dynamic dimension, it is considered the master
714  // operand that determines the runtime size of the output dimension.
715  auto targetSize =
716  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
717  if (operandsWithDynamicDim.size() == 1)
718  return {targetSize, operandsWithDynamicDim[0]};
719 
720  // Calculate maximum size among all dynamic dimensions
721  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
722  auto nextSize =
723  getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
724  targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
725  }
726  return {targetSize, nullptr};
727 }
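// Illustrative example: for operands tensor<?x3xf32> and tensor<1x3xf32> at
// dim 0, no operand has a static size greater than 1 and only the first
// operand is dynamic in that dimension, so the target size is
// 'tensor.dim %operand0, %c0' and the first operand is the master operand.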
728 
729 // Compute the runtime output size for all dimensions. This function returns
730 // a pair {targetShape, masterOperands}.
731 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
732 computeTargetShape(PatternRewriter &rewriter, Location loc,
733  IndexPool &indexPool, ValueRange operands) {
734  assert(!operands.empty());
735  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
736  SmallVector<OpFoldResult> targetShape;
737  SmallVector<Value> masterOperands;
738  for (auto dim : llvm::seq<int64_t>(0, rank)) {
739  auto [targetSize, masterOperand] =
740  computeTargetSize(rewriter, loc, indexPool, operands, dim);
741  targetShape.push_back(targetSize);
742  masterOperands.push_back(masterOperand);
743  }
744  return {targetShape, masterOperands};
745 }
746 
747 static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
748  IndexPool &indexPool, Value operand,
749  int64_t dim, OpFoldResult targetSize,
750  Value masterOperand) {
751  // Nothing to do if this is a static dimension
752  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
753  if (!rankedTensorType.isDynamicDim(dim))
754  return operand;
755 
756  // If the target size for this dimension was directly inferred by only taking
757  // this operand into account, there is no need to broadcast. This is an
758  // optimization that will prevent redundant control flow, and constitutes the
759  // main motivation for tracking "master operands".
760  if (operand == masterOperand)
761  return operand;
762 
763  // Affine maps for 'linalg.generic' op
764  auto rank = rankedTensorType.getRank();
765  SmallVector<AffineExpr> affineExprs;
766  for (auto index : llvm::seq<int64_t>(0, rank)) {
767  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
768  : rewriter.getAffineDimExpr(index);
769  affineExprs.push_back(affineExpr);
770  }
771  auto broadcastAffineMap =
772  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
773  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
774  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
775 
776  // Check if broadcast is necessary
777  auto one = createIndex(rewriter, loc, indexPool, 1);
778  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
779  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
780  loc, arith::CmpIPredicate::eq, runtimeSize, one);
781 
782  // Emit 'then' region of 'scf.if'
783  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
784  // It is not safe to cache constants across regions.
785  // New constants could potentially violate dominance requirements.
786  IndexPool localPool;
787 
788  // Emit 'tensor.empty' op
789  SmallVector<OpFoldResult> outputTensorShape;
790  for (auto index : llvm::seq<int64_t>(0, rank)) {
791  auto size = index == dim ? targetSize
792  : getOrFoldTensorDim(rewriter, loc, localPool,
793  operand, index);
794  outputTensorShape.push_back(size);
795  }
796  Value outputTensor = opBuilder.create<tensor::EmptyOp>(
797  loc, outputTensorShape, rankedTensorType.getElementType());
798 
799  // Emit 'linalg.generic' op
800  auto resultTensor =
801  opBuilder
802  .create<linalg::GenericOp>(
803  loc, outputTensor.getType(), operand, outputTensor, affineMaps,
804  getNParallelLoopsAttrs(rank),
805  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
806  // Emit 'linalg.yield' op
807  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
808  })
809  .getResult(0);
810 
811  // Cast to original operand type if necessary
812  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
813  loc, operand.getType(), resultTensor);
814 
815  // Emit 'scf.yield' op
816  opBuilder.create<scf::YieldOp>(loc, castResultTensor);
817  };
818 
819  // Emit 'else' region of 'scf.if'
820  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
821  opBuilder.create<scf::YieldOp>(loc, operand);
822  };
823 
824  // Emit 'scf.if' op
825  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
826  emitThenRegion, emitElseRegion);
827  return ifOp.getResult(0);
828 }
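// The control flow emitted above is roughly (illustrative sketch):
//
//   %size = tensor.dim %operand, %dim
//   %needs_bcast = arith.cmpi eq, %size, %c1 : index
//   %result = scf.if %needs_bcast -> (tensor<...>) {
//     // linalg.generic replicating the size-1 slice to the target size
//     scf.yield %broadcasted
//   } else {
//     scf.yield %operand
//   }
//
// so the broadcast only runs when the runtime size is actually 1.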
829 
830 static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
831  IndexPool &indexPool, Value operand,
832  ArrayRef<OpFoldResult> targetShape,
833  ArrayRef<Value> masterOperands) {
834  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
835  assert((int64_t)targetShape.size() == rank);
836  assert((int64_t)masterOperands.size() == rank);
837  for (auto index : llvm::seq<int64_t>(0, rank))
838  operand =
839  broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
840  targetShape[index], masterOperands[index]);
841  return operand;
842 }
843 
844 static SmallVector<Value>
845 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
846  IndexPool &indexPool, ValueRange operands,
847  ArrayRef<OpFoldResult> targetShape,
848  ArrayRef<Value> masterOperands) {
849  // No need to broadcast for unary operations
850  if (operands.size() == 1)
851  return operands;
852 
853  // Broadcast dynamic dimensions operand by operand
854  return llvm::map_to_vector(operands, [&](Value operand) {
855  return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
856  targetShape, masterOperands);
857  });
858 }
859 
860 static LogicalResult
861 emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
862  Operation *operation, ValueRange operands,
863  ArrayRef<OpFoldResult> targetShape,
864  const TypeConverter &converter) {
865  // Generate output tensor
866  auto resultType = cast_or_null<RankedTensorType>(
867  converter.convertType(operation->getResultTypes().front()));
868  if (!resultType) {
869  return rewriter.notifyMatchFailure(operation, "failed to convert type");
870  }
871  Value outputTensor = rewriter.create<tensor::EmptyOp>(
872  loc, targetShape, resultType.getElementType());
873 
874  // Create affine maps. Input affine maps broadcast static dimensions of size
875  // 1. The output affine map is an identity map.
876  //
877  auto rank = resultType.getRank();
878  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
879  auto shape = cast<ShapedType>(operand.getType()).getShape();
880  SmallVector<AffineExpr> affineExprs;
881  for (auto it : llvm::enumerate(shape)) {
882  auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
883  : rewriter.getAffineDimExpr(it.index());
884  affineExprs.push_back(affineExpr);
885  }
886  return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
887  });
888  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
889 
890  // Emit 'linalg.generic' op
891  bool encounteredError = false;
892  auto linalgOp = rewriter.create<linalg::GenericOp>(
893  loc, outputTensor.getType(), operands, outputTensor, affineMaps,
894  getNParallelLoopsAttrs(rank),
895  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
896  Value opResult = createLinalgBodyCalculationForElementwiseOp(
897  operation, blockArgs.take_front(operation->getNumOperands()),
898  {resultType.getElementType()}, rewriter);
899  if (!opResult) {
900  encounteredError = true;
901  return;
902  }
903  opBuilder.create<linalg::YieldOp>(loc, opResult);
904  });
905  if (encounteredError)
906  return rewriter.notifyMatchFailure(
907  operation, "unable to create linalg.generic body for elementwise op");
908 
909  // Cast 'linalg.generic' result into original result type if needed
910  auto castResult = rewriter.createOrFold<tensor::CastOp>(
911  loc, resultType, linalgOp->getResult(0));
912  rewriter.replaceOp(operation, castResult);
913  return success();
914 }
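// Illustrative example: adding tensor<1x5xf32> and tensor<4x5xf32> produces a
// 'linalg.generic' with indexing maps roughly
//   (d0, d1) -> (0, d1)    first input (static size-1 dim broadcast)
//   (d0, d1) -> (d0, d1)   second input
//   (d0, d1) -> (d0, d1)   output (identity)
// and a single 'arith.addf' (plus the yield) in its body.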
915 
916 static LogicalResult
917 elementwiseMatchAndRewrite(Operation *operation, ValueRange operands,
918  ConversionPatternRewriter &rewriter,
919  const TypeConverter &converter) {
920 
921  // Collect op properties
922  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
923  assert(operation->getNumOperands() >= 1 &&
924  "elementwise op expects at least 1 operand");
925  if (!operandsAndResultsRanked(operation))
926  return rewriter.notifyMatchFailure(operation,
927  "Unranked tensors not supported");
928 
929  // Lower operation
930  IndexPool indexPool;
931  auto loc = operation->getLoc();
932  auto rank =
933  cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
934  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
935  auto [targetShape, masterOperands] =
936  computeTargetShape(rewriter, loc, indexPool, expandedOperands);
937  auto broadcastOperands = broadcastDynamicDimensions(
938  rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
939  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
940  targetShape, converter);
941 }
942 
943 // Returns the constant initial value for a given reduction operation. The
944 // attribute type varies depending on the element type required.
945 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
946  PatternRewriter &rewriter) {
947  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
948  return rewriter.getFloatAttr(elementTy, 0.0);
949 
950  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
951  return rewriter.getIntegerAttr(elementTy, 0);
952 
953  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
954  return rewriter.getFloatAttr(elementTy, 1.0);
955 
956  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
957  return rewriter.getIntegerAttr(elementTy, 1);
958 
959  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
960  return rewriter.getFloatAttr(
961  elementTy, APFloat::getLargest(
962  cast<FloatType>(elementTy).getFloatSemantics(), false));
963 
964  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
965  return rewriter.getIntegerAttr(
966  elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
967 
968  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
969  return rewriter.getFloatAttr(
970  elementTy, APFloat::getLargest(
971  cast<FloatType>(elementTy).getFloatSemantics(), true));
972 
973  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
974  return rewriter.getIntegerAttr(
975  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
976 
977  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
978  return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
979 
980  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
981  return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
982 
983  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
984  return rewriter.getFloatAttr(
985  elementTy, APFloat::getLargest(
986  cast<FloatType>(elementTy).getFloatSemantics(), true));
987 
988  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
989  return rewriter.getIntegerAttr(
990  elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
991 
992  return {};
993 }
994 
995 // Creates the body calculation for a reduction. The operations vary depending
996 // on the input type.
997 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
998  ValueRange args,
999  Type elementTy,
1000  PatternRewriter &rewriter) {
1001  Location loc = op->getLoc();
1002  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1003  return rewriter.create<arith::AddFOp>(loc, args);
1004  }
1005 
1006  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1007  return rewriter.create<arith::AddIOp>(loc, args);
1008  }
1009 
1010  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
1011  return rewriter.create<arith::MulFOp>(loc, args);
1012  }
1013 
1014  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
1015  return rewriter.create<arith::MulIOp>(loc, args);
1016  }
1017 
1018  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1019  return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1020  }
1021 
1022  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1023  return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1024  }
1025 
1026  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1027  return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1028  }
1029 
1030  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1031  return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1032  }
1033 
1034  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1035  return rewriter.create<arith::AndIOp>(loc, args);
1036 
1037  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1038  return rewriter.create<arith::OrIOp>(loc, args);
1039 
1040  return {};
1041 }
1042 
1043 // Performs the match and rewrite for reduction operations. This includes
1044 // declaring a correctly sized initial value, and the linalg.generic operation
1045 // that reduces across the specified axis.
1046 static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1047  PatternRewriter &rewriter) {
1048  auto loc = op->getLoc();
1049  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
1050  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
1051  auto elementTy = resultTy.getElementType();
1052  Value input = op->getOperand(0);
1053 
1054  SmallVector<int64_t> reduceShape;
1055  SmallVector<Value> dynDims;
1056  for (unsigned i = 0; i < inputTy.getRank(); i++) {
1057  if (axis != i) {
1058  reduceShape.push_back(inputTy.getDimSize(i));
1059  if (inputTy.isDynamicDim(i))
1060  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1061  }
1062  }
1063 
1064  // First fill the output buffer with the init value.
1065  auto emptyTensor =
1066  rewriter
1067  .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1068  dynDims)
1069  .getResult();
1070 
1071  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1072  if (!fillValueAttr)
1073  return rewriter.notifyMatchFailure(
1074  op, "No initial value found for reduction operation");
1075 
1076  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1077  auto filledTensor = rewriter
1078  .create<linalg::FillOp>(loc, ValueRange{fillValue},
1079  ValueRange{emptyTensor})
1080  .result();
1081 
1082  bool didEncounterError = false;
1083  auto linalgOp = rewriter.create<linalg::ReduceOp>(
1084  loc, input, filledTensor, axis,
1085  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1086  auto result = createLinalgBodyCalculationForReduceOp(
1087  op, blockArgs, elementTy, rewriter);
1088  if (result)
1089  didEncounterError = true;
1090 
1091  nestedBuilder.create<linalg::YieldOp>(loc, result);
1092  });
1093 
1094  if (!didEncounterError)
1095  return rewriter.notifyMatchFailure(
1096  op, "unable to create linalg.generic body for reduce op");
1097 
1098  SmallVector<ReassociationExprs, 4> reassociationMap;
1099  uint64_t expandInputRank =
1100  cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
1101  reassociationMap.resize(expandInputRank);
1102 
1103  for (uint64_t i = 0; i < expandInputRank; i++) {
1104  int32_t dimToPush = i > axis ? i + 1 : i;
1105  reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1106  }
1107 
1108  if (expandInputRank != 0) {
1109  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1110  reassociationMap[expandedDim].push_back(
1111  rewriter.getAffineDimExpr(expandedDim + 1));
1112  }
1113 
1114  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1115  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1116  // not have access to such information. This matters when handling dynamically
1117  // sized tensors.
1118  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1119  op, resultTy, linalgOp.getResults()[0], reassociationMap);
1120  return success();
1121 }
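// Illustrative example: 'tosa.reduce_sum' over axis 1 of a tensor<4x5xf32>
// lowers roughly to
//
//   %empty  = tensor.empty() : tensor<4xf32>
//   %filled = linalg.fill ins(%zero : f32) outs(%empty : tensor<4xf32>)
//   %sum    = linalg.reduce { arith.addf } ins(%input) outs(%filled)
//             dimensions = [1]
//   %result = tensor.expand_shape %sum [[0, 1]] : tensor<4xf32> into
//             tensor<4x1xf32>
//
// keeping the reduced axis as a unit dimension in the final result.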
1122 
1123 namespace {
1124 
1125 template <typename SrcOp>
1126 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1127 public:
1128  using OpConversionPattern<SrcOp>::OpConversionPattern;
1129  using typename OpConversionPattern<SrcOp>::OpAdaptor;
1130 
1131  LogicalResult
1132  matchAndRewrite(SrcOp op, OpAdaptor operands,
1133  ConversionPatternRewriter &rewriter) const final {
1134  return elementwiseMatchAndRewrite(
1135  op, operands.getOperands(), rewriter, *this->getTypeConverter());
1136  }
1137 };
1138 
1139 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1140 public:
1141  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1142 
1143  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1144  PatternRewriter &rewriter) const final {
1145  auto loc = op.getLoc();
1146  auto input = op.getInput();
1147  auto inputTy = cast<ShapedType>(op.getInput().getType());
1148  auto outputTy = cast<ShapedType>(op.getOutput().getType());
1149  unsigned rank = inputTy.getRank();
1150 
1151  // This is an illegal configuration; terminate and log an error.
1152  if (op.getDoubleRound() && !op.getScale32())
1153  return rewriter.notifyMatchFailure(
1154  op, "tosa.rescale requires scale32 for double_round to be true");
1155 
1156  if (!isa<IntegerType>(inputTy.getElementType()))
1157  return rewriter.notifyMatchFailure(op, "only support integer type");
1158 
1159  SmallVector<Value> dynDims;
1160  for (int i = 0; i < outputTy.getRank(); i++) {
1161  if (outputTy.isDynamicDim(i)) {
1162  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1163  }
1164  }
1165 
1166  // The shift and multiplier values.
1167  SmallVector<int32_t> multiplierValues(op.getMultiplier());
1168  SmallVector<int8_t> shiftValues(op.getShift());
1169 
1170  // If we shift by more than the bitwidth, this just sets the result to 0.
1171  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1172  if (shiftValues[i] > 63) {
1173  shiftValues[i] = 0;
1174  multiplierValues[i] = 0;
1175  }
1176  }
1177 
1178  // Double rounding only occurs if the shift is greater than 31; check
1179  // whether this is ever the case.
1180  bool doubleRound =
1181  op.getDoubleRound() &&
1182  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1183 
1184  SmallVector<AffineMap> indexingMaps = {
1185  rewriter.getMultiDimIdentityMap(rank)};
1186  SmallVector<Value, 4> genericInputs = {input};
1187 
1188  // If we are rescaling per-channel then we need to store the multiplier
1189  // values in a buffer.
1190  Value multiplierConstant;
1191  int64_t multiplierArg = 0;
1192  if (multiplierValues.size() == 1) {
1193  multiplierConstant = rewriter.create<arith::ConstantOp>(
1194  loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1195  } else {
1196  SmallVector<AffineExpr, 2> multiplierExprs{
1197  rewriter.getAffineDimExpr(rank - 1)};
1198  auto multiplierType =
1199  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1200  rewriter.getI32Type());
1201  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1202  loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1203 
1204  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1205  /*symbolCount=*/0, multiplierExprs,
1206  rewriter.getContext()));
1207 
1208  multiplierArg = indexingMaps.size() - 1;
1209  }
1210 
1211  // If we are rescaling per-channel then we need to store the shift
1212  // values in a buffer.
1213  Value shiftConstant;
1214  int64_t shiftArg = 0;
1215  if (shiftValues.size() == 1) {
1216  shiftConstant = rewriter.create<arith::ConstantOp>(
1217  loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1218  } else {
1219  SmallVector<AffineExpr, 2> shiftExprs = {
1220  rewriter.getAffineDimExpr(rank - 1)};
1221  auto shiftType =
1222  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1223  rewriter.getIntegerType(8));
1224  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1225  loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1226  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1227  /*symbolCount=*/0, shiftExprs,
1228  rewriter.getContext()));
1229  shiftArg = indexingMaps.size() - 1;
1230  }
1231 
1232  // Indexing maps for output values.
1233  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1234 
1235  // Construct the indexing maps needed for linalg.generic ops.
1236  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1237  loc, outputTy.getShape(), outputTy.getElementType(),
1238  ArrayRef<Value>({dynDims}));
1239 
1240  auto linalgOp = rewriter.create<linalg::GenericOp>(
1241  loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1242  getNParallelLoopsAttrs(rank),
1243  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1244  ValueRange blockArgs) {
1245  Value value = blockArgs[0];
1246  Type valueTy = value.getType();
1247 
1248  // For now we do the zero-point math in a wider (32- or 48-bit) type. This
1249  // is not optimal but should be correct for now; consider computing the
1250  // exact bit depth later.
1251  int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1252 
1253  auto inputZp = createConstFromIntAttribute<int32_t>(
1254  op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1255  nestedBuilder);
1256  auto outputZp = createConstFromIntAttribute<int32_t>(
1257  op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1258 
1259  Value multiplier = multiplierConstant ? multiplierConstant
1260  : blockArgs[multiplierArg];
1261  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1262 
1263  if (valueTy.getIntOrFloatBitWidth() < 32) {
1264  if (op.getInputUnsigned()) {
1265  value = nestedBuilder.create<arith::ExtUIOp>(
1266  nestedLoc, nestedBuilder.getI32Type(), value);
1267  } else {
1268  value = nestedBuilder.create<arith::ExtSIOp>(
1269  nestedLoc, nestedBuilder.getI32Type(), value);
1270  }
1271  }
1272 
1273  value =
1274  nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1275 
1276  value = nestedBuilder.create<tosa::ApplyScaleOp>(
1277  loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1278  nestedBuilder.getBoolAttr(doubleRound));
1279 
1280  // Move to the new zero-point.
1281  value =
1282  nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1283 
1284  // Saturate to the output size.
1285  IntegerType outIntType =
1286  cast<IntegerType>(blockArgs.back().getType());
1287  unsigned outBitWidth = outIntType.getWidth();
1288 
1289  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1290  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1291 
1292  // Unsigned integers have a different output value range.
1293  if (op.getOutputUnsigned()) {
1294  intMin = 0;
1295  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1296  }
1297 
1298  auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1299  loc, nestedBuilder.getI32IntegerAttr(intMin));
1300  auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1301  loc, nestedBuilder.getI32IntegerAttr(intMax));
1302 
1303  value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1304  nestedBuilder, /*isUnsigned=*/false);
1305 
1306  if (outIntType.getWidth() < 32) {
1307  value = nestedBuilder.create<arith::TruncIOp>(
1308  nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1309  value);
1310  }
1311 
1312  nestedBuilder.create<linalg::YieldOp>(loc, value);
1313  });
1314 
1315  rewriter.replaceOp(op, linalgOp->getResults());
1316  return success();
1317  }
1318 };
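// Per element, the rescale body above computes roughly (illustrative):
//
//   widened = ext(value) - input_zp
//   scaled  = tosa.apply_scale(widened, multiplier, shift)
//   result  = clamp(scaled + output_zp, out_min, out_max), truncated to the
//             output width
//
// where multiplier and shift are either scalar constants or read per channel
// from constant tensors.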
1319 
1320 // Handle the resize case where the input is a 1x1 image. This case can
1321 // entirely avoid the extract operations, which are much more difficult to
1322 // optimize away.
1323 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1324 public:
1325  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1326 
1327  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1328  PatternRewriter &rewriter) const final {
1329  Location loc = op.getLoc();
1330  ImplicitLocOpBuilder builder(loc, rewriter);
1331  auto input = op.getInput();
1332  auto inputTy = cast<RankedTensorType>(input.getType());
1333  auto resultTy = cast<RankedTensorType>(op.getType());
1334  const bool isBilinear = op.getMode() == "BILINEAR";
1335 
1336  auto inputH = inputTy.getDimSize(1);
1337  auto inputW = inputTy.getDimSize(2);
1338  auto outputH = resultTy.getDimSize(1);
1339  auto outputW = resultTy.getDimSize(2);
1340 
1341  if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1342  return rewriter.notifyMatchFailure(
1343  op, "tosa.resize is not a pure 1x1->1x1 image operation");
1344 
1345  // TODO(suderman): These string values should be declared in the TOSA dialect.
1346  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1347  return rewriter.notifyMatchFailure(
1348  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1349 
1350  if (inputTy == resultTy) {
1351  rewriter.replaceOp(op, input);
1352  return success();
1353  }
1354 
1355  ArrayRef<int64_t> scale = op.getScale();
1356 
1357  // Collapse the unit width and height away.
1358  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1359  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1360  reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1361  reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1362  reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1363 
1364  auto collapseTy =
1365  RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1366  inputTy.getElementType());
1367  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1368  reassociationMap);
1369 
1370  // Get any dynamic shapes that appear in the input format.
1371  llvm::SmallVector<Value> outputDynSize;
1372  if (inputTy.isDynamicDim(0))
1373  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1374  if (inputTy.isDynamicDim(3))
1375  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1376 
1377  // Generate the elementwise operation for casting and scaling the input value.
1378  auto genericTy = collapseTy.clone(resultTy.getElementType());
1379  Value empty = builder.create<tensor::EmptyOp>(
1380  genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1381  auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1382  SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1383  utils::IteratorType::parallel);
1384 
1385  auto generic = builder.create<linalg::GenericOp>(
1386  genericTy, ValueRange{collapse}, ValueRange{empty},
1387  ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1388  [=](OpBuilder &b, Location loc, ValueRange args) {
1389  Value value = args[0];
1390  // This is the quantized case.
1391  if (inputTy.getElementType() != resultTy.getElementType()) {
1392  value =
1393  b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1394 
1395  if (isBilinear && scale[0] != 0) {
1396  Value scaleY = b.create<arith::ConstantOp>(
1397  loc, b.getI32IntegerAttr(scale[0]));
1398  value = b.create<arith::MulIOp>(loc, value, scaleY);
1399  }
1400 
1401  if (isBilinear && scale[2] != 0) {
1402  Value scaleX = b.create<arith::ConstantOp>(
1403  loc, b.getI32IntegerAttr(scale[2]));
1404  value = b.create<arith::MulIOp>(loc, value, scaleX);
1405  }
1406  }
1407 
1408  b.create<linalg::YieldOp>(loc, value);
1409  });
1410 
1411  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1412  op, resultTy, generic.getResults()[0], reassociationMap);
1413  return success();
1414  }
1415 };
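// Illustrative example: a BILINEAR 'tosa.resize' from tensor<1x1x1x8xi8> to a
// 1x1 output of type tensor<1x1x1x8xi32> collapses the unit H/W dims to
// tensor<1x8xi8>, widens and scales each element inside a 'linalg.generic'
// (multiplying by scale[0] and scale[2] in the quantized bilinear case), and
// expands back to NHWC, with no 'tensor.extract' ops at all.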
1416 
1417 // TOSA resize with a width or height of 1 may be broadcast to a wider
1418 // dimension. This is done by materializing a new tosa.resize without
1419 // the broadcasting behavior, and an explicit broadcast afterwards.
1420 class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1421 public:
1422  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1423 
1424  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1425  PatternRewriter &rewriter) const final {
1426  Location loc = op.getLoc();
1427  ImplicitLocOpBuilder builder(loc, rewriter);
1428  auto input = op.getInput();
1429  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1430  auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1431 
1432  if (!inputTy || !resultTy)
1433  return rewriter.notifyMatchFailure(op,
1434  "requires ranked input/output types");
1435 
1436  auto batch = inputTy.getDimSize(0);
1437  auto channels = inputTy.getDimSize(3);
1438  auto inputH = inputTy.getDimSize(1);
1439  auto inputW = inputTy.getDimSize(2);
1440  auto outputH = resultTy.getDimSize(1);
1441  auto outputW = resultTy.getDimSize(2);
1442 
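   // Broadcasting only occurs when an input dimension is 1 and the
   // corresponding output dimension is larger; bail out if neither the height
   // nor the width needs it.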
1443  if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1444  return rewriter.notifyMatchFailure(
1445  op, "tosa.resize has no broadcasting behavior");
1446 
 1447  // For any dimension that is broadcastable we generate a size of 1
1448  // on the output.
1449  llvm::SmallVector<int64_t> resizeShape;
1450  resizeShape.push_back(batch);
1451  resizeShape.push_back(inputH == 1 ? 1 : outputH);
1452  resizeShape.push_back(inputW == 1 ? 1 : outputW);
1453  resizeShape.push_back(channels);
1454 
1455  auto resizeTy = resultTy.clone(resizeShape);
1456  auto resize =
1457  builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1458 
 1459  // Collapse any unit result dims.
1460  SmallVector<ReassociationExprs, 4> reassociationMap(2);
1461  reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1462  reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1463  if (inputH != 1)
1464  reassociationMap.push_back({});
1465  reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1466  if (inputW != 1)
1467  reassociationMap.push_back({});
1468  reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1469 
1470  llvm::SmallVector<int64_t> collapseShape{batch};
1471  if (inputH != 1)
1472  collapseShape.push_back(outputH);
1473  if (inputW != 1)
1474  collapseShape.push_back(outputW);
1475  collapseShape.push_back(channels);
1476 
1477  auto collapseTy = resultTy.clone(collapseShape);
1478  Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1479  reassociationMap);
1480 
1481  // Broadcast the collapsed shape to the output result.
1482  llvm::SmallVector<Value> outputDynSize;
1483  if (inputTy.isDynamicDim(0))
1484  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1485  if (inputTy.isDynamicDim(3))
1486  outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1487 
1488  SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1489  utils::IteratorType::parallel);
1490  Value empty = builder.create<tensor::EmptyOp>(
1491  resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1492 
1493  SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1494  if (inputH != 1)
1495  inputExprs.push_back(rewriter.getAffineDimExpr(1));
1496  if (inputW != 1)
1497  inputExprs.push_back(rewriter.getAffineDimExpr(2));
1498  inputExprs.push_back(rewriter.getAffineDimExpr(3));
1499 
1500  auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1501  inputExprs, rewriter.getContext());
1502 
1503  auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1504  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1505  op, resultTy, ValueRange{collapse}, ValueRange{empty},
1506  ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1507  [=](OpBuilder &b, Location loc, ValueRange args) {
1508  Value value = args[0];
1509  b.create<linalg::YieldOp>(loc, value);
1510  });
1511 
1512  return success();
1513  }
1514 };
1515 
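 // Lowers tosa.resize to a linalg.generic over the output tensor. The body
 // computes the corresponding input coordinates and performs nearest-neighbor
 // or bilinear sampling, in floating-point or integer arithmetic depending on
 // the result element type.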
1516 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1517 public:
 1518   using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
 1519 
1520  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1521  PatternRewriter &rewriter) const final {
1522  Location loc = op.getLoc();
1523  ImplicitLocOpBuilder b(loc, rewriter);
1524  auto input = op.getInput();
1525  auto inputTy = cast<ShapedType>(input.getType());
1526  auto resultTy = cast<ShapedType>(op.getType());
1527  auto resultETy = resultTy.getElementType();
1528 
1529  bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1530  auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1531 
1532  auto imageH = inputTy.getShape()[1];
1533  auto imageW = inputTy.getShape()[2];
1534 
1535  auto dynamicDimsOr =
1536  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1537  if (!dynamicDimsOr.has_value())
1538  return rewriter.notifyMatchFailure(
1539  op, "unable to get dynamic dimensions of tosa.resize");
1540 
1541  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1542  return rewriter.notifyMatchFailure(
1543  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1544 
1545  SmallVector<AffineMap, 2> affineMaps = {
1546  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1547  auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1548  *dynamicDimsOr);
1549  auto genericOp = b.create<linalg::GenericOp>(
1550  resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1551  getNParallelLoopsAttrs(resultTy.getRank()));
1552  Value resize = genericOp.getResult(0);
1553 
1554  {
1555  OpBuilder::InsertionGuard regionGuard(b);
1556  b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1557  TypeRange({resultETy}), loc);
1558  Value batch = b.create<linalg::IndexOp>(0);
1559  Value y = b.create<linalg::IndexOp>(1);
1560  Value x = b.create<linalg::IndexOp>(2);
1561  Value channel = b.create<linalg::IndexOp>(3);
1562 
1563  Value zeroI32 =
1564  b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1565  Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1566  Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1567  Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1568 
1569  Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1570  Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1571 
1572  ArrayRef<int64_t> offset = op.getOffset();
1573  ArrayRef<int64_t> border = op.getBorder();
1574  ArrayRef<int64_t> scale = op.getScale();
1575 
1576  Value yScaleN, yScaleD, xScaleN, xScaleD;
1577  yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1578  yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1579  xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1580  xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1581 
1582  Value yOffset, xOffset, yBorder, xBorder;
1583  yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1584  xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1585  yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1586  xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1587 
1588  // Compute the ix and dx values for both the X and Y dimensions.
1589  auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1590  Value scaleN, Value scaleD, Value offset,
1591  int size, ImplicitLocOpBuilder &b) {
1592  if (size == 1) {
1593  index = zeroI32;
1594  delta = zeroFp;
1595  return;
1596  }
1597  // x = x * scale_d + offset;
1598  // ix = floor(x / scale_n)
1599  Value val = b.create<arith::MulIOp>(in, scaleD);
1600  val = b.create<arith::AddIOp>(val, offset);
1601  index = b.create<arith::FloorDivSIOp>(val, scaleN);
1602 
1603  // rx = x % scale_n
1604  // dx = rx / scale_n
1605  Value r = b.create<arith::RemSIOp>(val, scaleN);
1606  Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1607  Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1608  delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1609  };
1610 
1611  // Compute the ix and dx values for the X and Y dimensions - int case.
1612  auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1613  Value scaleN, Value scaleD, Value offset,
1614  int size, ImplicitLocOpBuilder &b) {
1615  if (size == 1) {
1616  index = zeroI32;
1617  delta = zeroI32;
1618  return;
1619  }
1620  // x = x * scale_d + offset;
1621  // ix = floor(x / scale_n)
1622  // dx = x - ix * scale_n;
1623  Value val = b.create<arith::MulIOp>(in, scaleD);
1624  val = b.create<arith::AddIOp>(val, offset);
1625  index = b.create<arith::DivSIOp>(val, scaleN);
1626  delta = b.create<arith::MulIOp>(index, scaleN);
1627  delta = b.create<arith::SubIOp>(val, delta);
1628  };
1629 
1630  Value ix, iy, dx, dy;
1631  if (floatingPointMode) {
1632  getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1633  getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1634  } else {
1635  getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1636  getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1637  }
1638 
1639  if (op.getMode() == "NEAREST_NEIGHBOR") {
1640  auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1641 
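       // Round to the nearest input position: advance the index by one when
       // the fractional part is at least one half (dval >= 0.5 in float mode,
       // 2 * dval >= scale_n in integer mode), then clamp into [0, size - 1].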
1642  auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1643  Value max, int size,
1644  ImplicitLocOpBuilder &b) -> Value {
1645  if (size == 1) {
1646  return b.create<arith::ConstantIndexOp>(0);
1647  }
1648 
1649  Value pred;
1650  if (floatingPointMode) {
1651  auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1652  pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1653  } else {
1654  Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1655  pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1656  dvalDouble, scale);
1657  }
1658 
1659  auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1660  val = b.create<arith::AddIOp>(val, offset);
1661  val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
1662  return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1663  };
1664 
1665  iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1666  ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1667 
1668  Value result = b.create<tensor::ExtractOp>(
1669  input, ValueRange{batch, iy, ix, channel});
1670 
1671  b.create<linalg::YieldOp>(result);
1672  } else {
1673  // The mode here must be BILINEAR.
1674  assert(op.getMode() == "BILINEAR");
1675 
1676  auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1677 
1678  auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
 1679                             Value max, ImplicitLocOpBuilder &b) {
 1680  val0 = in;
1681  val1 = b.create<arith::AddIOp>(val0, oneVal);
1682  val0 =
1683  clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1684  val1 =
1685  clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
1686  val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1687  val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1688  };
1689 
1690  // Linalg equivalent to the section below:
1691  // int16_t iy0 = apply_max(iy, 0);
1692  // int16_t iy1 = apply_min(iy + 1, IH - 1);
1693  // int16_t ix0 = apply_max(ix, 0);
1694  // int16_t ix1 = apply_min(ix + 1, IW - 1);
1695  Value x0, x1, y0, y1;
1696  getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1697  getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1698 
1699  Value y0x0 = b.create<tensor::ExtractOp>(
1700  input, ValueRange{batch, y0, x0, channel});
1701  Value y0x1 = b.create<tensor::ExtractOp>(
1702  input, ValueRange{batch, y0, x1, channel});
1703  Value y1x0 = b.create<tensor::ExtractOp>(
1704  input, ValueRange{batch, y1, x0, channel});
1705  Value y1x1 = b.create<tensor::ExtractOp>(
1706  input, ValueRange{batch, y1, x1, channel});
1707 
1708  if (floatingPointMode) {
1709  auto oneVal =
1710  b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1711  auto interpolate = [&](Value val0, Value val1, Value delta,
1712  int inputSize,
1713  ImplicitLocOpBuilder &b) -> Value {
1714  if (inputSize == 1)
1715  return val0;
1716  Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1717  Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1718  Value mul1 = b.create<arith::MulFOp>(val1, delta);
1719  return b.create<arith::AddFOp>(mul0, mul1);
1720  };
1721 
1722  // Linalg equivalent to the section below:
1723  // topAcc = v00 * (unit_x - dx);
1724  // topAcc += v01 * dx;
1725  Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1726 
1727  // Linalg equivalent to the section below:
1728  // bottomAcc = v10 * (unit_x - dx);
1729  // bottomAcc += v11 * dx;
1730  Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1731 
1732  // Linalg equivalent to the section below:
1733  // result = topAcc * (unit_y - dy) + bottomAcc * dy
1734  Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1735  b.create<linalg::YieldOp>(result);
1736  } else {
1737  // Perform in quantized space.
1738  y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1739  y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1740  y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1741  y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1742 
1743  const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1744  if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1745  dx = b.create<arith::ExtSIOp>(resultETy, dx);
1746  dy = b.create<arith::ExtSIOp>(resultETy, dy);
1747  }
1748 
1749  Value yScaleNExt = yScaleN;
1750  Value xScaleNExt = xScaleN;
1751 
1752  const int64_t scaleBitwidth =
1753  xScaleN.getType().getIntOrFloatBitWidth();
1754  if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1755  yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1756  xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1757  }
1758 
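         // Integer bilinear interpolation computes val0 * (scale_n - d) +
         // val1 * d, so the result remains scaled by the scale numerator of
         // each axis; unit-sized axes simply multiply by scale_n.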
1759  auto interpolate = [](Value val0, Value val1, Value weight1,
1760  Value scale, int inputSize,
1761  ImplicitLocOpBuilder &b) -> Value {
1762  if (inputSize == 1)
1763  return b.create<arith::MulIOp>(val0, scale);
1764  Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1765  Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1766  Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1767  return b.create<arith::AddIOp>(mul0, mul1);
1768  };
1769 
1770  Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1771  Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1772  Value result =
1773  interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1774  b.create<linalg::YieldOp>(result);
1775  }
1776  }
1777  }
1778 
1779  rewriter.replaceOp(op, resize);
1780  return success();
1781  }
1782 };
1783 
1784 // At the codegen level any identity operations should be removed. Any cases
1785 // where identity is load-bearing (e.g. cross device computation) should be
1786 // handled before lowering to codegen.
1787 template <typename SrcOp>
1788 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1789 public:
 1790   using OpRewritePattern<SrcOp>::OpRewritePattern;
 1791 
1792  LogicalResult matchAndRewrite(SrcOp op,
1793  PatternRewriter &rewriter) const final {
1794  rewriter.replaceOp(op, op.getOperation()->getOperands());
1795  return success();
1796  }
1797 };
1798 
1799 template <typename SrcOp>
1800 class ReduceConverter : public OpRewritePattern<SrcOp> {
1801 public:
 1802   using OpRewritePattern<SrcOp>::OpRewritePattern;
 1803 
1804  LogicalResult matchAndRewrite(SrcOp reduceOp,
1805  PatternRewriter &rewriter) const final {
1806  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1807  }
1808 };
1809 
1810 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1811 public:
 1812   using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
 1813 
1814  LogicalResult matchAndRewrite(tosa::ReverseOp op,
1815  PatternRewriter &rewriter) const final {
1816  auto loc = op.getLoc();
1817  Value input = op.getInput1();
1818  auto inputTy = cast<ShapedType>(input.getType());
1819  auto resultTy = cast<ShapedType>(op.getType());
1820  auto axis = op.getAxis();
1821 
1822  SmallVector<Value> dynDims;
1823  for (int i = 0; i < inputTy.getRank(); i++) {
1824  if (inputTy.isDynamicDim(i)) {
1825  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1826  }
1827  }
1828 
1829  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1830 
 1831  // Create the output buffer; every element is overwritten by the generic below.
1832  auto emptyTensor = rewriter
1833  .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1834  inputTy.getElementType(),
1835  ArrayRef<Value>({dynDims}))
1836  .getResult();
1837  SmallVector<AffineMap, 2> affineMaps = {
1838  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1839 
1840  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1841  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
1842  getNParallelLoopsAttrs(resultTy.getRank()),
1843  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1844  llvm::SmallVector<Value> indices;
1845  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1846  Value index =
1847  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1848  if (i == axis) {
1849  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1850  auto sizeMinusOne =
1851  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1852  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1853  index);
1854  }
1855 
1856  indices.push_back(index);
1857  }
1858 
1859  auto extract = nestedBuilder.create<tensor::ExtractOp>(
1860  nestedLoc, input, indices);
1861  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1862  extract.getResult());
1863  });
1864  return success();
1865  }
1866 };
1867 
 1868 // This converter translates a tile operation to a reshape, broadcast, and
 1869 // reshape. The first reshape minimally expands each tiled dimension to
 1870 // include a preceding size-1 dim. This dim is then broadcast to the
 1871 // appropriate multiple.
1872 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
 1873   using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
 1874 
1875  LogicalResult
1876  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1877  ConversionPatternRewriter &rewriter) const override {
1878  auto loc = op.getLoc();
1879  auto input = op.getInput1();
1880  auto inputTy = cast<ShapedType>(input.getType());
1881  auto inputShape = inputTy.getShape();
1882  auto resultTy = cast<ShapedType>(op.getType());
1883  auto elementTy = inputTy.getElementType();
1884  int64_t rank = inputTy.getRank();
1885 
1886  ArrayRef<int64_t> multiples = op.getMultiples();
1887 
1888  // Broadcast the newly added dimensions to their appropriate multiple.
1889  SmallVector<int64_t, 2> genericShape;
1890  for (int i = 0; i < rank; i++) {
1891  int64_t dim = multiples[i];
1892  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1893  genericShape.push_back(inputShape[i]);
1894  }
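     // e.g. an input of shape [2, 3] with multiples [5, 7] yields genericShape
     // [5, 2, 7, 3], which the trailing reshape collapses to the final
     // [10, 21] result.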
1895 
1896  SmallVector<Value> dynDims;
1897  for (int i = 0; i < inputTy.getRank(); i++) {
1898  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1899  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1900  }
1901  }
1902 
1903  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1904  op.getLoc(), genericShape, elementTy, dynDims);
1905 
 1906  // We need to map the input shape to the non-broadcast dimensions.
1907  SmallVector<AffineExpr, 4> dimExprs;
1908  dimExprs.reserve(rank);
1909  for (unsigned i = 0; i < rank; ++i)
1910  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1911 
1912  auto readAffineMap =
1913  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1914  rewriter.getContext());
1915 
1916  SmallVector<AffineMap, 2> affineMaps = {
1917  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1918 
1919  auto genericOp = rewriter.create<linalg::GenericOp>(
1920  loc, RankedTensorType::get(genericShape, elementTy), input,
1921  ValueRange{emptyTensor}, affineMaps,
1922  getNParallelLoopsAttrs(genericShape.size()),
1923  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1924  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1925  });
1926 
1927  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1928  op, resultTy, genericOp.getResult(0),
1929  rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
1930  return success();
1931  }
1932 };
1933 
 1934 // The TOSA argmax lowering represents the ArgMax op as a linalg.generic op,
 1935 // producing two output buffers.
 1936 //
 1937 // The first output buffer contains the index of the found maximum value. It is
 1938 // initialized to 0 and is of the resulting integer type.
 1939 //
 1940 // The second output buffer contains the maximum value found. It is initialized
 1941 // to the minimum representable value of the input element type. After being
 1942 // populated by the generic op, this buffer is discarded as only the index is
 1943 // requested.
 1944 //
 1945 // The generic op updates both the maximum value and index if the
 1946 // current value exceeds the running max.
1947 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1948 public:
 1949   using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
 1950 
1951  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1952  PatternRewriter &rewriter) const final {
1953  auto loc = argmaxOp.getLoc();
1954  Value input = argmaxOp.getInput();
1955  auto inputTy = cast<ShapedType>(input.getType());
1956  auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
1957  auto inElementTy = inputTy.getElementType();
1958  auto outElementTy = resultTy.getElementType();
1959  int axis = argmaxOp.getAxis();
1960  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1961 
1962  if (!isa<IntegerType>(outElementTy))
1963  return rewriter.notifyMatchFailure(
1964  argmaxOp,
1965  "tosa.arg_max to linalg.* requires integer-like result type");
1966 
1967  SmallVector<Value> dynDims;
1968  for (int i = 0; i < inputTy.getRank(); i++) {
1969  if (inputTy.isDynamicDim(i) && i != axis) {
1970  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1971  }
1972  }
1973 
1974  // First fill the output buffer for the index.
1975  auto emptyTensorIdx = rewriter
1976  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1977  outElementTy, dynDims)
1978  .getResult();
1979  auto fillValueIdx = rewriter.create<arith::ConstantOp>(
1980  loc, rewriter.getIntegerAttr(outElementTy, 0));
1981  auto filledTensorIdx =
1982  rewriter
1983  .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
1984  ValueRange{emptyTensorIdx})
1985  .result();
1986 
1987  // Second fill the output buffer for the running max.
1988  auto emptyTensorMax = rewriter
1989  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1990  inElementTy, dynDims)
1991  .getResult();
1992  auto fillValueMaxAttr =
1993  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
1994 
1995  if (!fillValueMaxAttr)
1996  return rewriter.notifyMatchFailure(
1997  argmaxOp, "unsupported tosa.argmax element type");
1998 
1999  auto fillValueMax =
2000  rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
2001  auto filledTensorMax =
2002  rewriter
2003  .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
2004  ValueRange{emptyTensorMax})
2005  .result();
2006 
2007  // We need to reduce along the arg-max axis, with parallel operations along
2008  // the rest.
 2009  SmallVector<utils::IteratorType> iteratorTypes;
 2010  iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2011  iteratorTypes[axis] = utils::IteratorType::reduction;
2012 
2013  SmallVector<AffineExpr, 2> srcExprs;
2014  SmallVector<AffineExpr, 2> dstExprs;
2015  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2016  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2017  if (axis != i)
2018  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2019  }
2020 
2021  bool didEncounterError = false;
2022  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2023  rewriter.getContext());
2024  auto linalgOp = rewriter.create<linalg::GenericOp>(
2025  loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2026  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2027  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2028  ValueRange blockArgs) {
2029  auto newValue = blockArgs[0];
2030  auto oldIndex = blockArgs[1];
2031  auto oldValue = blockArgs[2];
2032 
2033  Value newIndex = rewriter.create<arith::IndexCastOp>(
2034  nestedLoc, oldIndex.getType(),
2035  rewriter.create<linalg::IndexOp>(loc, axis));
2036 
2037  Value predicate;
2038  if (isa<FloatType>(inElementTy)) {
2039  predicate = rewriter.create<arith::CmpFOp>(
2040  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2041  } else if (isa<IntegerType>(inElementTy)) {
2042  predicate = rewriter.create<arith::CmpIOp>(
2043  nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2044  } else {
2045  didEncounterError = true;
2046  return;
2047  }
2048 
2049  auto resultMax = rewriter.create<arith::SelectOp>(
2050  nestedLoc, predicate, newValue, oldValue);
2051  auto resultIndex = rewriter.create<arith::SelectOp>(
2052  nestedLoc, predicate, newIndex, oldIndex);
2053  nestedBuilder.create<linalg::YieldOp>(
2054  nestedLoc, ValueRange({resultIndex, resultMax}));
2055  });
2056 
2057  if (didEncounterError)
2058  return rewriter.notifyMatchFailure(
2059  argmaxOp, "unsupported tosa.argmax element type");
2060 
2061  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2062  return success();
2063  }
2064 };
2065 
2066 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2067 public:
 2068   using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
 2069  LogicalResult
2070  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2071  ConversionPatternRewriter &rewriter) const final {
2072  auto input = adaptor.getOperands()[0];
2073  auto indices = adaptor.getOperands()[1];
2074 
2075  auto valuesTy =
2076  dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2077  auto resultTy = cast<ShapedType>(op.getType());
2078 
2079  if (!valuesTy)
2080  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2081 
2082  auto dynamicDims = inferDynamicDimsForGather(
2083  rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2084 
2085  auto resultElementTy = resultTy.getElementType();
2086 
2087  auto loc = op.getLoc();
2088  auto emptyTensor =
2089  rewriter
2090  .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2091  dynamicDims)
2092  .getResult();
2093 
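     // The indices operand is read at (batch, index) while iterating the full
     // (batch, index, channel) output space; the body extracts
     // values[batch, indices[batch, index], channel] for each output element.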
2094  SmallVector<AffineMap, 2> affineMaps = {
 2095  AffineMap::get(
 2096  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2097  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2098  rewriter.getContext()),
2099  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2100 
2101  auto genericOp = rewriter.create<linalg::GenericOp>(
2102  loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2103  ValueRange{emptyTensor}, affineMaps,
2104  getNParallelLoopsAttrs(resultTy.getRank()),
2105  [&](OpBuilder &b, Location loc, ValueRange args) {
2106  auto indexValue = args[0];
2107  auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2108  Value index1 = rewriter.create<arith::IndexCastOp>(
2109  loc, rewriter.getIndexType(), indexValue);
2110  auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2111  Value extract = rewriter.create<tensor::ExtractOp>(
2112  loc, input, ValueRange{index0, index1, index2});
2113  rewriter.create<linalg::YieldOp>(loc, extract);
2114  });
2115  rewriter.replaceOp(op, genericOp.getResult(0));
2116  return success();
2117  }
2118 
2119  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2120  Location loc,
2121  Value values,
2122  Value indices) {
2123  llvm::SmallVector<Value> results;
2124 
2125  auto addDynamicDimension = [&](Value source, int64_t dim) {
2126  auto sz = tensor::getMixedSize(builder, loc, source, dim);
2127  if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2128  results.push_back(dimValue);
2129  };
2130 
2131  addDynamicDimension(values, 0);
2132  addDynamicDimension(indices, 1);
2133  addDynamicDimension(values, 2);
2134  return results;
2135  }
2136 };
2137 
 2138 // Lowers the TableOp to a series of gathers and numeric operations. This
 2139 // includes interpolation between the high/low values. For the I8 variant, this
 2140 // simplifies to a single gather operation.
2141 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2142 public:
 2143   using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
 2144 
2145  LogicalResult matchAndRewrite(tosa::TableOp op,
2146  PatternRewriter &rewriter) const final {
2147  auto loc = op.getLoc();
2148  Value input = op.getInput1();
2149  Value table = op.getTable();
2150  auto inputTy = cast<ShapedType>(input.getType());
2151  auto tableTy = cast<ShapedType>(table.getType());
2152  auto resultTy = cast<ShapedType>(op.getType());
2153 
2154  auto inputElementTy = inputTy.getElementType();
2155  auto tableElementTy = tableTy.getElementType();
2156  auto resultElementTy = resultTy.getElementType();
2157 
2158  SmallVector<Value> dynDims;
2159  for (int i = 0; i < resultTy.getRank(); ++i) {
2160  if (inputTy.isDynamicDim(i)) {
2161  dynDims.push_back(
2162  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2163  }
2164  }
2165 
2166  auto emptyTensor = rewriter
2167  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2168  resultElementTy, dynDims)
2169  .getResult();
2170 
2171  SmallVector<AffineMap, 2> affineMaps = {
2172  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2173  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2174 
2175  auto genericOp = rewriter.create<linalg::GenericOp>(
2176  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2177  getNParallelLoopsAttrs(resultTy.getRank()));
2178  rewriter.replaceOp(op, genericOp.getResult(0));
2179 
2180  {
2181  OpBuilder::InsertionGuard regionGuard(rewriter);
2182  Block *block = rewriter.createBlock(
2183  &genericOp.getRegion(), genericOp.getRegion().end(),
2184  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2185 
2186  auto inputValue = block->getArgument(0);
2187  rewriter.setInsertionPointToStart(block);
2188  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2189  resultElementTy.isInteger(8)) {
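         // i8 case: shift the signed input by +128 to form an index into the
         // 256-entry table and perform a single gather.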
2190  Value index = rewriter.create<arith::IndexCastOp>(
2191  loc, rewriter.getIndexType(), inputValue);
2192  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2193  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2194  index, offset);
2195  Value extract =
2196  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2197  rewriter.create<linalg::YieldOp>(loc, extract);
2198  return success();
2199  }
2200 
2201  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2202  resultElementTy.isInteger(32)) {
2203  Value extend = rewriter.create<arith::ExtSIOp>(
2204  loc, rewriter.getI32Type(), inputValue);
2205 
2206  auto offset = rewriter.create<arith::ConstantOp>(
2207  loc, rewriter.getI32IntegerAttr(32768));
2208  auto seven = rewriter.create<arith::ConstantOp>(
2209  loc, rewriter.getI32IntegerAttr(7));
2210  auto one = rewriter.create<arith::ConstantOp>(
2211  loc, rewriter.getI32IntegerAttr(1));
2212  auto b1111111 = rewriter.create<arith::ConstantOp>(
2213  loc, rewriter.getI32IntegerAttr(127));
2214 
2215  // Compute the index and fractional part from the input value:
2216  // value = value + 32768
2217  // index = value >> 7;
2218  // fraction = 0x01111111 & value
2219  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2220  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2221  Value fraction =
2222  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2223 
2224  // Extract the base and next values from the table.
2225  // base = (int32_t) table[index];
2226  // next = (int32_t) table[index + 1];
2227  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2228 
2229  index = rewriter.create<arith::IndexCastOp>(
2230  loc, rewriter.getIndexType(), index);
2231  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2232  loc, rewriter.getIndexType(), indexPlusOne);
2233 
2234  Value base =
2235  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2236  Value next = rewriter.create<tensor::ExtractOp>(
2237  loc, table, ValueRange{indexPlusOne});
2238 
2239  base =
2240  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2241  next =
2242  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2243 
2244  // Use the fractional part to interpolate between the input values:
2245  // result = (base << 7) + (next - base) * fraction
2246  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2247  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2248  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2249  Value result =
2250  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2251 
2252  rewriter.create<linalg::YieldOp>(loc, result);
2253 
2254  return success();
2255  }
2256  }
2257 
2258  return rewriter.notifyMatchFailure(
2259  op, "unable to create body for tosa.table op");
2260  }
2261 };
2262 
2263 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
 2264   using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
 2265 
2266  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2267 
2268  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2269  OpFoldResult ofr) {
2270  auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
2271  auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
2272 
2273  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2274  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2275  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2276  return getAsOpFoldResult(plusOne);
2277  }
2278 
2279  static RankedTensorType
2280  computeOutputShape(OpBuilder &builder, Location loc, Value input,
2281  llvm::SmallVectorImpl<Value> &dynamicSizes) {
2282  // Get [N, H, W]
2283  auto dims = tensor::getMixedSizes(builder, loc, input);
2284 
2285  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2286  // output tensors.
2287  dims[2] = halfPlusOne(builder, loc, dims[2]);
2288 
2289  llvm::SmallVector<int64_t, 3> staticSizes;
2290  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2291 
2292  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2293  return RankedTensorType::get(staticSizes, elementType);
2294  }
2295 
2296  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2297  RankedTensorType type,
2298  llvm::ArrayRef<Value> dynamicSizes) {
2299  auto emptyTensor =
2300  rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2301  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2302  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2303  auto filledTensor = rewriter
2304  .create<linalg::FillOp>(loc, ValueRange{fillValue},
2305  ValueRange{emptyTensor})
2306  .result();
2307  return filledTensor;
2308  }
2309 
2310  static Value castIndexToFloat(OpBuilder &builder, Location loc,
2311  FloatType type, Value value) {
2312  auto integerVal = builder.create<arith::IndexCastUIOp>(
2313  loc,
2314  type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2315  : builder.getI32Type(),
2316  value);
2317 
2318  return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2319  }
2320 
2321  static Value createLinalgIndex(OpBuilder &builder, Location loc,
2322  FloatType type, int64_t index) {
2323  auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2324  return castIndexToFloat(builder, loc, type, indexVal);
2325  }
2326 
2327  template <typename... Args>
2328  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2329  Args... args) {
2330  return {builder.getAffineDimExpr(args)...};
2331  }
2332 
2333  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2334  PatternRewriter &rewriter) const override {
2335  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2336  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2337  return rewriter.notifyMatchFailure(rfft2d,
2338  "only supports ranked tensors");
2339  }
2340 
2341  auto loc = rfft2d.getLoc();
2342  auto input = rfft2d.getInput();
2343  auto elementType =
2344  dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2345  if (!elementType)
2346  return rewriter.notifyMatchFailure(rfft2d,
2347  "only supports float element types");
2348 
2349  // Compute the output type and set of dynamic sizes
2350  llvm::SmallVector<Value> dynamicSizes;
2351  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2352 
2353  // Iterator types for the linalg.generic implementation
 2354  SmallVector<utils::IteratorType, 5> iteratorTypes = {
 2355  utils::IteratorType::parallel, utils::IteratorType::parallel,
2356  utils::IteratorType::parallel, utils::IteratorType::reduction,
2357  utils::IteratorType::reduction};
2358 
2359  // Inputs/outputs to the linalg.generic implementation
2360  llvm::SmallVector<Value> genericOpInputs = {input};
2361  llvm::SmallVector<Value> genericOpOutputs = {
2362  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2363  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2364 
2365  // Indexing maps for input and output tensors
2366  auto indexingMaps = AffineMap::inferFromExprList(
2367  llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2368  affineDimsExpr(rewriter, 0, 1, 2),
2369  affineDimsExpr(rewriter, 0, 1, 2)},
2370  rewriter.getContext());
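     // The input is read at (batch, iy, ix) = (d0, d3, d4) while both outputs
     // use (batch, oy, ox) = (d0, d1, d2); d3 and d4 are reduction dimensions
     // that sum over the entire input plane.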
2371 
2372  // Width and height dimensions of the original input.
2373  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2374  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2375 
2376  // Constants and dimension sizes
2377  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2378  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2379  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2380  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2381 
2382  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2383  Value valReal = args[0];
2384  Value sumReal = args[1];
2385  Value sumImag = args[2];
2386 
2387  // Indices for angle computation
2388  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2389  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2390  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2391  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2392 
2393  // Calculating angle without integer parts of components as sin/cos are
2394  // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2395  // / W);
2396  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2397  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2398 
2399  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2400  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2401 
2402  auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2403  auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2404 
2405  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2406  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2407  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2408  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2409 
2410  // realComponent = valReal * cos(angle)
2411  // imagComponent = valReal * sin(angle)
2412  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2413  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2414  auto realComponent =
2415  builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2416  auto imagComponent =
2417  builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2418 
2419  // outReal = sumReal + realComponent
2420  // outImag = sumImag - imagComponent
2421  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2422  auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2423 
2424  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2425  };
2426 
2427  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2428  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2429  indexingMaps, iteratorTypes, buildBody);
2430 
2431  return success();
2432  }
2433 };
2434 
2435 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
 2436   using OpRewritePattern<FFT2dOp>::OpRewritePattern;
 2437 
2438  LogicalResult matchAndRewrite(FFT2dOp fft2d,
2439  PatternRewriter &rewriter) const override {
2440  if (!llvm::all_of(fft2d->getOperandTypes(),
2441  RFFT2dConverter::isRankedTensor) ||
2442  !llvm::all_of(fft2d->getResultTypes(),
2443  RFFT2dConverter::isRankedTensor)) {
2444  return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2445  }
2446 
2447  Location loc = fft2d.getLoc();
2448  Value input_real = fft2d.getInputReal();
2449  Value input_imag = fft2d.getInputImag();
2450  BoolAttr inverse = fft2d.getInverseAttr();
2451 
2452  auto real_el_ty = cast<FloatType>(
2453  cast<ShapedType>(input_real.getType()).getElementType());
2454  [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2455  cast<ShapedType>(input_imag.getType()).getElementType());
2456 
2457  assert(real_el_ty == imag_el_ty);
2458 
2459  // Compute the output type and set of dynamic sizes
2460  SmallVector<Value> dynamicSizes;
2461 
2462  // Get [N, H, W]
2463  auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2464 
2465  SmallVector<int64_t, 3> staticSizes;
2466  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2467 
2468  auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2469 
2470  // Iterator types for the linalg.generic implementation
2471  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2472  utils::IteratorType::parallel, utils::IteratorType::parallel,
2473  utils::IteratorType::parallel, utils::IteratorType::reduction,
2474  utils::IteratorType::reduction};
2475 
2476  // Inputs/outputs to the linalg.generic implementation
2477  SmallVector<Value> genericOpInputs = {input_real, input_imag};
2478  SmallVector<Value> genericOpOutputs = {
2479  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2480  dynamicSizes),
2481  RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2482  dynamicSizes)};
2483 
2484  // Indexing maps for input and output tensors
2485  auto indexingMaps = AffineMap::inferFromExprList(
2486  ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2487  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2488  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2489  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2490  rewriter.getContext());
2491 
2492  // Width and height dimensions of the original input.
2493  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2494  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2495 
2496  // Constants and dimension sizes
2497  auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2498  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2499  Value constH =
2500  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2501  Value constW =
2502  RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2503 
2504  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2505  Value valReal = args[0];
2506  Value valImag = args[1];
2507  Value sumReal = args[2];
2508  Value sumImag = args[3];
2509 
2510  // Indices for angle computation
2511  Value oy = builder.create<linalg::IndexOp>(loc, 1);
2512  Value ox = builder.create<linalg::IndexOp>(loc, 2);
2513  Value iy = builder.create<linalg::IndexOp>(loc, 3);
2514  Value ix = builder.create<linalg::IndexOp>(loc, 4);
2515 
2516  // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2517  // ox) % W ) / W);
2518  auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2519  auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2520 
2521  auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2522  auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2523 
2524  auto iyRemFloat =
2525  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2526  auto ixRemFloat =
2527  RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2528 
2529  auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2530  auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2531 
2532  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2533  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2534 
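       // The inverse transform negates the angle, i.e. it conjugates the
       // twiddle factor.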
2535  if (inverse.getValue()) {
2536  angle = builder.create<arith::MulFOp>(
2537  loc, angle,
2538  rewriter.create<arith::ConstantOp>(
2539  loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2540  }
2541 
2542  // realComponent = val_real * cos(a) + val_imag * sin(a);
2543  // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2544  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2545  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2546 
2547  auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2548  auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2549  auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2550 
2551  auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2552  auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2553 
2554  auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2555 
2556  // outReal = sumReal + realComponent
 2557  // outImag = sumImag + imagComponent
2558  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2559  auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2560 
2561  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2562  };
2563 
2564  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2565  fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2566  indexingMaps, iteratorTypes, buildBody);
2567 
2568  return success();
2569  }
2570 };
2571 
2572 } // namespace
2573 
 2574 void mlir::tosa::populateTosaToLinalgConversionPatterns(
 2575  const TypeConverter &converter, RewritePatternSet *patterns) {
2576 
 2577  // We have multiple resize converters to handle degenerate cases.
2578  patterns->add<GenericResizeConverter>(patterns->getContext(),
2579  /*benefit=*/100);
2580  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2581  /*benefit=*/200);
2582  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2583  /*benefit=*/300);
2584 
2585  patterns->add<
2586  // clang-format off
2587  PointwiseConverter<tosa::AddOp>,
2588  PointwiseConverter<tosa::SubOp>,
2589  PointwiseConverter<tosa::MulOp>,
2590  PointwiseConverter<tosa::IntDivOp>,
2591  PointwiseConverter<tosa::NegateOp>,
2592  PointwiseConverter<tosa::PowOp>,
2593  PointwiseConverter<tosa::ReciprocalOp>,
2594  PointwiseConverter<tosa::RsqrtOp>,
2595  PointwiseConverter<tosa::LogOp>,
2596  PointwiseConverter<tosa::ExpOp>,
2597  PointwiseConverter<tosa::AbsOp>,
2598  PointwiseConverter<tosa::SinOp>,
2599  PointwiseConverter<tosa::CosOp>,
2600  PointwiseConverter<tosa::TanhOp>,
2601  PointwiseConverter<tosa::ErfOp>,
2602  PointwiseConverter<tosa::BitwiseAndOp>,
2603  PointwiseConverter<tosa::BitwiseOrOp>,
2604  PointwiseConverter<tosa::BitwiseNotOp>,
2605  PointwiseConverter<tosa::BitwiseXorOp>,
2606  PointwiseConverter<tosa::LogicalAndOp>,
2607  PointwiseConverter<tosa::LogicalNotOp>,
2608  PointwiseConverter<tosa::LogicalOrOp>,
2609  PointwiseConverter<tosa::LogicalXorOp>,
2610  PointwiseConverter<tosa::CastOp>,
2611  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2612  PointwiseConverter<tosa::LogicalRightShiftOp>,
2613  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2614  PointwiseConverter<tosa::ClzOp>,
2615  PointwiseConverter<tosa::SelectOp>,
2616  PointwiseConverter<tosa::GreaterOp>,
2617  PointwiseConverter<tosa::GreaterEqualOp>,
2618  PointwiseConverter<tosa::EqualOp>,
2619  PointwiseConverter<tosa::MaximumOp>,
2620  PointwiseConverter<tosa::MinimumOp>,
2621  PointwiseConverter<tosa::CeilOp>,
2622  PointwiseConverter<tosa::FloorOp>,
2623  PointwiseConverter<tosa::ClampOp>,
2624  PointwiseConverter<tosa::SigmoidOp>
2625  >(converter, patterns->getContext());
2626 
2627  patterns->add<
2628  IdentityNConverter<tosa::IdentityOp>,
2629  ReduceConverter<tosa::ReduceAllOp>,
2630  ReduceConverter<tosa::ReduceAnyOp>,
2631  ReduceConverter<tosa::ReduceMinOp>,
2632  ReduceConverter<tosa::ReduceMaxOp>,
2633  ReduceConverter<tosa::ReduceSumOp>,
2634  ReduceConverter<tosa::ReduceProdOp>,
2635  ArgMaxConverter,
2636  GatherConverter,
2637  RescaleConverter,
2638  ReverseConverter,
2639  RFFT2dConverter,
2640  FFT2dConverter,
2641  TableConverter,
2642  TileConverter>(patterns->getContext());
2643  // clang-format on
2644 }