//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

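// Creates the scalar computation for a single TOSA elementwise op inside the
// body of a linalg.generic, dispatching on the op kind and element type.
// Returns nullptr (after notifying match failure) for unsupported cases.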
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                              args[0], zero);
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

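  // Quantized multiplication. With a non-zero shift the product is computed
  // at 32 bits and rescaled with tosa.apply_scale, which evaluates
  // (a * b) >> shift with rounding; the result is then truncated back to the
  // narrower element type when needed.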
  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    auto constant =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp = quantizationInfo.value().getInputZp();
    int64_t outZp = quantizationInfo.value().getOutputZp();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //   outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp = clampIntHelper(loc, sub, min, max, rewriter);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Check that the shift amount (args[1]) is greater than zero.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Check whether the last bit shifted out of args[0] is 1.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    // Round up when the shift is non-zero and the last shifted-out bit is 1.
    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    bool losesInfo = false;
    APFloat minApf = op->getAttr("min_fp").cast<FloatAttr>().getValue();
    APFloat maxApf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
    minApf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    return clampFloatHelper(loc, args[0], min, max, rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto intTy = elementTy.cast<IntegerType>();
    int32_t min = static_cast<int32_t>(
        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
    int32_t max = static_cast<int32_t>(
        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());

    if (intTy.isUnsignedInteger()) {
      min = std::max<int32_t>(min, 0);
      max = std::min<int32_t>(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min = std::max<int32_t>(
          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
      max = std::min<int32_t>(
          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
    }

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

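  // tosa::CastOp lowering overview: float-to-float uses ext/trunc, int-to-float
  // uses SIToFP/UIToFP, float-to-int rounds half away from zero and clamps
  // before FPToSI, int-to-int sign-extends or clamps and truncates, and any
  // cast to i1 compares against zero.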
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             mlir::None);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              mlir::None);

    // When casting to boolean, floats need only be checked as not-equal to
    // zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(0.0f));
      auto half = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(0.5f));

      auto intMin = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto added = rewriter.create<arith::AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<arith::SubFOp>(loc, args[0], half);
      auto negative = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<arith::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);

      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
    }

    // When casting to boolean, integers need only be checked as not-equal to
    // zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             mlir::None);

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      auto intMin = rewriter.create<arith::ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<arith::ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampIntHelper(loc, args[0], intMin, intMax, rewriter);
      return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

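// Lowers a TOSA elementwise op to a linalg.generic over the result shape.
// Operands that already match the result shape get identity indexing maps;
// broadcast operands have their broadcast dimensions reshaped away and
// receive a reduced indexing map over the remaining dimensions.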
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> emptyTensors;

  SmallVector<Value> dynDims;
  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());

  for (auto arg : operation->getOperands()) {
    auto operandTy = arg.getType().cast<ShapedType>();
    for (int i = 0; i < operandTy.getRank(); i++) {
      if (operandTy.isDynamicDim(i) && !dynDims[i])
        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
    }
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);

  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    emptyTensors.push_back(rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      emptyTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (const auto &it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
          rewriter.getI64ArrayAttr(newShape));
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/rank, /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, emptyTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  llvm::SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<utils::IteratorType, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? utils::IteratorType::reduction
                                      : utils::IteratorType::parallel);
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  // Despite its name, didEncounterError is set when a reduction body was
  // successfully built; bail out if no body could be created.
  if (!didEncounterError)
    return failure();

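  // The reduction above produced a tensor with the reduced axis removed;
  // re-insert it as a size-1 dimension with tensor.expand_shape so the result
  // matches the TOSA result type.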
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}

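// Determines whether lhsShape and rhsShape can be matched up by multiplying
// out adjacent runs of dimensions, returning the common intermediate shape.
// For example, [2, 3, 4] and [6, 4] share the intermediate shape [6, 4].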
static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
                                  ArrayRef<int64_t> rhsShape,
                                  SmallVector<int64_t> &intermediateShape,
                                  bool isDynamic) {
  if (isDynamic) {
    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
    intermediateShape = {ShapedType::kDynamic};
    return true;
  }

  if (lhsShape.empty() || rhsShape.empty()) {
    intermediateShape = {};
    return true;
  }

  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        lhsSize *= lhsShape[currLhsDim];
      } else {
        currRhsDim++;
        rhsSize *= rhsShape[currRhsDim];
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // If an iterator didn't reach the end, any leftover dimensions must be
  // equal to 1; otherwise no intermediate shape was found.
  while (currLhsDim < lhsShape.size()) {
    if (lhsShape[currLhsDim++] != 1) {
      return false;
    }
  }

  while (currRhsDim < rhsShape.size()) {
    if (rhsShape[currRhsDim++] != 1) {
      return false;
    }
  }

  return true;
}

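// Builds the reassociation indices used by tensor.collapse_shape to map
// srcShape onto dstShape. For example, collapsing [2, 3, 4] into [6, 4]
// yields the reassociation [[0, 1], [2]].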
static bool createReassociationMapsForCollapse(
    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
    ArrayRef<int64_t> dstShape,
    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {

  // If the shape is dynamic, create a map for collapsing into one dimension.
  if (isDynamic) {
    SmallVector<AffineExpr, 2> exprs;
    for (int i = 0, s = srcShape.size(); i < s; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    reassociationMap = {exprs};
    return true;
  }

  if (dstShape.empty()) {
    reassociationMap = {};
    return true;
  }

  reassociationMap.resize(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              rewriter.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If either iterator didn't reach the end, we have leftover dimensions,
  // which implies that we have a mismatch in shape.
  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
}

namespace {

template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};
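
// tosa.reshape is lowered in three forms: a pure collapse to
// tensor.collapse_shape, a pure expand to tensor.expand_shape, or, when
// neither alone suffices, a collapse through an intermediate shape followed
// by an expand.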
class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
    bool isDynamic = !operandTy.hasStaticShape();

    if (isDynamic && resultTy.getRank() != 1) {
      return rewriter.notifyMatchFailure(
          reshape, "Cannot collapse dynamic dims to more than one dimension");
    }

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    SmallVector<ReassociationExprs, 4> reassociationMap;
    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
                                            resultTy.getShape(),
                                            reassociationMap, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape,
          "tosa.reshape Attempting to collapse into an incompatible shape");
    }

    SmallVector<int64_t> intermediateShape;
    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
                               intermediateShape, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot collapse into given shape");
    }

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
    return success();
  }
};

class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
    bool isDynamic = !operandTy.hasStaticShape();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    if (isDynamic && operandTy.getRank() != 1) {
      return rewriter.notifyMatchFailure(
          reshape, "Cannot expand dynamic dims from more than one dimension");
    }

    SmallVector<ReassociationExprs, 4> reassociationMap;
    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
                                            operandTy.getShape(),
                                            reassociationMap, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape,
          "tosa.reshape Attempting to expand into an incompatible shape");
    }

    SmallVector<int64_t> intermediateShape;
    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
                               intermediateShape, isDynamic) ||
        intermediateShape != operandTy.getShape()) {
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot expand into given shape");
    }
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
    return success();
  }
};

class ReshapeConverterCollapseExpand
    : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
    bool isDynamic = !operandTy.hasStaticShape();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    SmallVector<int64_t> intermediateShape;
    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
                               intermediateShape, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot identify an intermediate shape between "
                   "the given two shapes");
    }

    Value collapse = rewriter.create<tosa::ReshapeOp>(
        reshape.getLoc(),
        RankedTensorType::get(intermediateShape,
                              reshape.getType().getElementType()),
        adaptor.getInput1());
    Value expand =
        rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
    rewriter.replaceOp(reshape, expand);

    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
      return failure();
    }

    auto loc = op.getLoc();
    auto input = op->getOperand(0);
    auto resultTy = op.getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

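    // Build the input indexing map from the permutation: input dimension
    // perms[i] is read at iteration dimension i, while the output uses the
    // identity map, so the generic gathers the transposed elements.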
    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    auto operandTy = input.getType().cast<ShapedType>();
    for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
      auto index = permutation.index();
      auto value = permutation.value().getZExtValue();
      if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
      }
      inputExprs[value] = rewriter.getAffineDimExpr(index);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims);

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });
    return success();
  }
};

class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = op.getInput().getType().cast<ShapedType>();
    auto outputTy = op.getOutput().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration; terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.getMultiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.getShift(), shiftValues);

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if the shift is greater than 31; check
    // whether any shift value actually requires it.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Create the output init tensor for the linalg.generic op.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // For now we widen all of our math to 32 bits (48 bits for wide
          // inputs). This is not optimal but should be correct for now;
          // consider computing the exact required bit depth later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};

// Handle the case where the resize operation is a regular broadcast. We
// perform this part separately to avoid generating Extract operations which
// are difficult to vectorize / optimize.
class BroadcastResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = input.getType().cast<RankedTensorType>();
    auto resultTy = op.getType().cast<RankedTensorType>();

    auto imageH = inputTy.getDimSize(1);
    auto imageW = inputTy.getDimSize(2);

    if (imageH != 1 || imageW != 1) {
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure broadcast operation");
    }

    // TODO(suderman): These string values should be declared in the TOSA
    // dialect.
    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return failure();

    const bool isBilinear = op.getMode() == "BILINEAR";

    SmallVector<int32_t> scale;
    getValuesFromIntArrayAttribute(op.getScale(), scale);

    // Collapse the unit height and width dimensions away.
    SmallVector<ReassociationExprs, 4> collapseMap(2);
    collapseMap[0].push_back(builder.getAffineDimExpr(0));
    collapseMap[1].push_back(builder.getAffineDimExpr(1));
    collapseMap[1].push_back(builder.getAffineDimExpr(2));
    collapseMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse =
        builder.create<tensor::CollapseShapeOp>(collapseTy, input, collapseMap);

    // Broadcast input to the output shape.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));

    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    llvm::SmallVector<AffineExpr> inputExprs{
        rewriter.getAffineDimExpr(0),
        rewriter.getAffineDimExpr(3),
    };

    auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
                                   inputExprs, builder.getContext());
    auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);

    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    auto generic = builder.create<linalg::GenericOp>(
        resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, resultMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, generic.getResult(0));
    return success();
  }
};

class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto resultElementTy = resultTy.getElementType();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return failure();

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultElementTy, dynamicDims);

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    Value resize = input;
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    resize = genericOp.getResult(0);

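    // The generic op is created without a body builder; the block is
    // constructed by hand below so the index arithmetic can be emitted
    // directly into it.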
    OpBuilder::InsertionGuard regionGuard(rewriter);
    rewriter.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                         TypeRange({resultElementTy}), loc);
    Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
    Value y = rewriter.create<linalg::IndexOp>(loc, 1);
    Value x = rewriter.create<linalg::IndexOp>(loc, 2);
    Value channel = rewriter.create<linalg::IndexOp>(loc, 3);

    auto hwMin =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
    auto hMax = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(imageH - 1));
    auto wMax = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(imageW - 1));

    Value inY =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
    Value inX =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);

    bool floatingPointMode = resultElementTy.isF32();

    Value yScaleN, yScaleD, xScaleN, xScaleD, yOffset, xOffset, yBorder,
        xBorder;
    SmallVector<int32_t> scale, offset, border;
    getValuesFromIntArrayAttribute(op.getScale(), scale);
    getValuesFromIntArrayAttribute(op.getOffset(), offset);
    getValuesFromIntArrayAttribute(op.getBorder(), border);

    yScaleN = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(scale[0]));
    yScaleD = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(scale[1]));
    xScaleN = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(scale[2]));
    xScaleD = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(scale[3]));
    yOffset = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(offset[0]));
    xOffset = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(offset[1]));
    yBorder = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(border[0]));
    xBorder = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(border[1]));

    // Compute the integer index and partial offset.
    Value ix, iy, dx, dy;
    // x = x * scale_d + offset;
    // ix = floor(x / scale_n)
    if (floatingPointMode) {
      // dx = x / scale_n - ix
      Value y =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
      Value x =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);

      yScaleN =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yScaleN);
      yScaleD =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yScaleD);
      xScaleN =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xScaleN);
      xScaleD =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xScaleD);
      yOffset =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yOffset);
      xOffset =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xOffset);

      y = rewriter.create<arith::MulFOp>(loc, y, yScaleD);
      x = rewriter.create<arith::MulFOp>(loc, x, xScaleD);

      y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
      x = rewriter.create<arith::AddFOp>(loc, x, xOffset);

      y = rewriter.create<arith::DivFOp>(loc, y, yScaleN);
      x = rewriter.create<arith::DivFOp>(loc, x, xScaleN);

      iy = rewriter.create<math::FloorOp>(loc, y);
      ix = rewriter.create<math::FloorOp>(loc, x);

      dy = rewriter.create<arith::SubFOp>(loc, y, iy);
      dx = rewriter.create<arith::SubFOp>(loc, x, ix);

      iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
      ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
    } else {
      // dx = x - ix * scale_n;
      Value y = rewriter.create<arith::MulIOp>(loc, inY, yScaleD);
      Value x = rewriter.create<arith::MulIOp>(loc, inX, xScaleD);

      y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
      x = rewriter.create<arith::AddIOp>(loc, x, xOffset);

      iy = rewriter.create<arith::DivUIOp>(loc, y, yScaleN);
      ix = rewriter.create<arith::DivUIOp>(loc, x, xScaleN);

      Value tempY = rewriter.create<arith::MulIOp>(loc, iy, yScaleN);
      Value tempX = rewriter.create<arith::MulIOp>(loc, ix, xScaleN);

      dy = rewriter.create<arith::SubIOp>(loc, y, tempY);
      dx = rewriter.create<arith::SubIOp>(loc, x, tempX);
    }

1564  if (op.getMode() == "NEAREST_NEIGHBOR") {
1565  Value yPred, xPred;
1566  auto zeroVal = rewriter.create<arith::ConstantOp>(
1567  loc, rewriter.getI32IntegerAttr(0));
1568  auto oneVal = rewriter.create<arith::ConstantOp>(
1569  loc, rewriter.getI32IntegerAttr(1));
1570 
1571  // Round the index position towards the closest pixel location.
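  // In float mode dy/dx are fractions in [0, 1), so "closest" means
  // dy >= 0.5; in integer mode they are remainders in units of scale_n,
  // so the equivalent test below is dy >= (scale_n >> 1).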
1572  if (floatingPointMode) {
1573  auto halfVal = rewriter.create<arith::ConstantOp>(
1574  loc, rewriter.getF32FloatAttr(0.5f));
1575  yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1576  dy, halfVal);
1577  xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1578  dx, halfVal);
1579  } else {
1580  Value yScaleNHalfVal =
1581  rewriter.create<arith::ShRSIOp>(loc, yScaleN, oneVal);
1582  Value xScaleNHalfVal =
1583  rewriter.create<arith::ShRSIOp>(loc, xScaleN, oneVal);
1584  yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1585  dy, yScaleNHalfVal);
1586  xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1587  dx, xScaleNHalfVal);
1588  }
1589 
1590  auto yOffset =
1591  rewriter.create<arith::SelectOp>(loc, yPred, oneVal, zeroVal);
1592  auto xOffset =
1593  rewriter.create<arith::SelectOp>(loc, xPred, oneVal, zeroVal);
1594 
1595  iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset);
1596  ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);
1597 
1598  // Clamp to be within the bounds of the input image.
1599  iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter);
1600  ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter);
1601 
1602  // Read the value from the input array.
1603  iy =
1604  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), iy);
1605  ix =
1606  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), ix);
1607 
1608  Value result = rewriter.create<tensor::ExtractOp>(
1609  loc, input, ValueRange{batch, iy, ix, channel});
1610 
1611  rewriter.create<linalg::YieldOp>(loc, result);
1612  } else {
1613  // The mode here must be BILINEAR.
1614  assert(op.getMode() == "BILINEAR");
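  // The code below computes the standard bilinear blend,
  //   result = y0x0 * (1 - dx) * (1 - dy) + y0x1 * dx * (1 - dy)
  //          + y1x0 * (1 - dx) * dy      + y1x1 * dx * dy,
  // factored as a horizontal lerp within rows y0 and y1 followed by a
  // vertical lerp between the two row accumulators.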
1615  Value y0 = iy;
1616  Value x0 = ix;
1617 
1618  auto oneVal = rewriter.create<arith::ConstantOp>(
1619  loc, rewriter.getI32IntegerAttr(1));
1620  Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
1621  Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
1622 
1623  y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter);
1624  y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter);
1625 
1626  x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter);
1627  x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter);
1628 
1629  y0 =
1630  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
1631  y1 =
1632  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y1);
1633  x0 =
1634  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x0);
1635  x1 =
1636  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x1);
1637 
1638  Value y0x0 = rewriter.create<tensor::ExtractOp>(
1639  loc, input, ValueRange{batch, y0, x0, channel});
1640  Value y0x1 = rewriter.create<tensor::ExtractOp>(
1641  loc, input, ValueRange{batch, y0, x1, channel});
1642  Value y1x0 = rewriter.create<tensor::ExtractOp>(
1643  loc, input, ValueRange{batch, y1, x0, channel});
1644  Value y1x1 = rewriter.create<tensor::ExtractOp>(
1645  loc, input, ValueRange{batch, y1, x1, channel});
1646 
1647  if (floatingPointMode) {
1648  Value rightPart = dx;
1649  auto oneVal = rewriter.create<arith::ConstantOp>(
1650  loc, rewriter.getF32FloatAttr(1.0f));
1651  Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
1652 
1653  y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
1654  y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
1655  Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);
1656 
1657  y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
1658  y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
1659  Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
1660 
1661  Value bottomPart = dy;
1662  Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
1663  topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
1664  bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
1665  Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
1666 
1667  rewriter.create<linalg::YieldOp>(loc, result);
1668  } else {
1669  // Perform in quantized space.
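  // Here dx/dy are remainders in units of xScaleN/yScaleN rather than
  // fractions in [0, 1), so the result carries an extra factor of
  // xScaleN * yScaleN. Illustrative check: with xScaleN = yScaleN = 4,
  // dx = dy = 2, and all four pixels equal to v, both row accumulators
  // come out as 4 * v and the final result is 16 * v.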
1670  y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
1671  y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
1672  y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
1673  y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
1674 
1675  if (resultElementTy.getIntOrFloatBitWidth() > 32) {
1676  dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
1677  dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
1678  }
1679 
1680  Value topAcc, bottomAcc;
1681  if (imageW == 1) {
1682  topAcc = rewriter.create<arith::MulIOp>(loc, y0x0, xScaleN);
1683  bottomAcc = rewriter.create<arith::MulIOp>(loc, y1x0, xScaleN);
1684  } else {
1685  Value rightPart = dx;
1686  Value leftPart = rewriter.create<arith::SubIOp>(loc, xScaleN, dx);
1687 
1688  y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
1689  y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
1690  topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
1691 
1692  y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
1693  y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
1694  bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
1695  }
1696 
1697  Value result;
1698  if (imageH == 1) {
1699  result = rewriter.create<arith::MulIOp>(loc, topAcc, yScaleN);
1700  } else {
1701  Value bottomPart = dy;
1702  Value topPart = rewriter.create<arith::SubIOp>(loc, yScaleN, dy);
1703  topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
1704  bottomAcc =
1705  rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
1706  result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
1707  }
1708 
1709  rewriter.create<linalg::YieldOp>(loc, result);
1710  }
1711  }
1712 
1713  rewriter.replaceOp(op, resize);
1714  return success();
1715  }
1716 };
1717 
1718 // At the codegen level, any identity operations should be removed. Any cases
1719 // where identity is load-bearing (e.g. cross-device computation) should be
1720 // handled before lowering to codegen.
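// For example, %0 = "tosa.identity"(%arg0) is removed by forwarding %arg0
// to every use of %0.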
1721 template <typename SrcOp>
1722 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1723 public:
1724  using OpRewritePattern<SrcOp>::OpRewritePattern;
1725 
1726  LogicalResult matchAndRewrite(SrcOp op,
1727  PatternRewriter &rewriter) const final {
1728  rewriter.replaceOp(op, op.getOperation()->getOperands());
1729  return success();
1730  }
1731 };
1732 
1733 template <typename SrcOp>
1734 class ReduceConverter : public OpRewritePattern<SrcOp> {
1735 public:
1736  using OpRewritePattern<SrcOp>::OpRewritePattern;
1737 
1738  LogicalResult matchAndRewrite(SrcOp reduceOp,
1739  PatternRewriter &rewriter) const final {
1740  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1741  }
1742 };
1743 
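// Lowers tosa.concat by summing the operand sizes along the concat axis,
// creating an empty result tensor, and tensor.insert_slice-ing each operand
// into it at a running offset along that axis.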
1744 struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
1745  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
1746 
1747  LogicalResult
1748  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
1749  ConversionPatternRewriter &rewriter) const override {
1750  auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
1751  auto resultType = op.getType().dyn_cast<RankedTensorType>();
1752 
1753  Location loc = op.getLoc();
1754  int axis = op.getAxis();
1755  Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
1756  loc, rewriter.getIndexAttr(axis));
1757  int rank = resultType.getRank();
1758  SmallVector<Value, 3> offsets, sizes, strides;
1759  sizes.reserve(rank);
1760  strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
1761  offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
1762 
1763  SmallVector<Value> dynDims;
1764  for (int i = 0; i < rank; ++i) {
1765  sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
1766  loc, adaptor.getOperands()[0], i));
1767  if (inputType.isDynamicDim(i)) {
1768  dynDims.push_back(
1769  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
1770  }
1771  }
1772 
1773  Value resultDimSize = sizes[axis];
1774  for (auto arg : adaptor.getOperands().drop_front()) {
1775  auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
1776  resultDimSize =
1777  rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
1778  }
1779  sizes[axis] = resultDimSize;
1780 
1781  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1782  loc, resultType.getShape(), resultType.getElementType(), dynDims);
1783 
1784  auto toOpFoldResult = [](Value v) -> OpFoldResult {
1785  auto op = v.getDefiningOp<arith::ConstantIndexOp>();
1786  if (!op)
1787  return v;
1788  return op.getValue();
1789  };
1790  Value result = emptyTensor;
1791  for (auto arg : adaptor.getOperands()) {
1792  sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
1793  result = rewriter.createOrFold<tensor::InsertSliceOp>(
1794  loc, arg, result,
1795  llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
1796  llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
1797  llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
1798  offsets[axis] =
1799  rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
1800  }
1801  rewriter.replaceOp(op, result);
1802  return success();
1803  }
1804 };
1805 
1806 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1807 public:
1808  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1809 
1810  LogicalResult matchAndRewrite(tosa::ReverseOp op,
1811  PatternRewriter &rewriter) const final {
1812  auto loc = op.getLoc();
1813  Value input = op.getInput();
1814  auto inputTy = input.getType().template cast<ShapedType>();
1815  auto resultTy = op.getType().template cast<ShapedType>();
1816  auto axis = op.getAxis();
1817 
1818  SmallVector<Value> dynDims;
1819  for (int i = 0; i < inputTy.getRank(); i++) {
1820  if (inputTy.isDynamicDim(i)) {
1821  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1822  }
1823  }
1824 
1825  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1826 
1827  // Create the empty output buffer; the generic op fully overwrites it.
1828  auto emptyTensor = rewriter
1829  .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1830  inputTy.getElementType(),
1831  ArrayRef<Value>({dynDims}))
1832  .getResult();
1833  SmallVector<AffineMap, 2> affineMaps = {
1834  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1835 
1836  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1837  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
1838  getNParallelLoopsAttrs(resultTy.getRank()),
1839  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1840  llvm::SmallVector<Value> indices;
1841  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1842  auto index =
1843  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1844  if (i == axis) {
1845  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1846  auto sizeMinusOne =
1847  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1848  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1849  index);
1850  }
1851 
1852  indices.push_back(index);
1853  }
1854 
1855  auto extract = nestedBuilder.create<tensor::ExtractOp>(
1856  nestedLoc, input, indices);
1857  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1858  extract.getResult());
1859  });
1860  return success();
1861  }
1862 };
1863 
1864 // This converter translates a tile operation to a reshape, broadcast, and
1865 // reshape. The first reshape minimally expands each tiled dimension to
1866 // include a preceding size-1 dim. This dim is then broadcast to the
1867 // appropriate multiple.
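// Shape sketch (illustrative): tiling a tensor<2x3xf32> with
// multiples = [3, 2] builds the intermediate shape [3, 2, 2, 3] (a
// size-"multiple" dim before each input dim), which the trailing
// tosa.reshape collapses into the tensor<6x6xf32> result.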
1868 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1869  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1870 
1871  LogicalResult
1872  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1873  ConversionPatternRewriter &rewriter) const override {
1874  auto loc = op.getLoc();
1875  auto input = op.getInput1();
1876  auto inputTy = input.getType().cast<ShapedType>();
1877  auto inputShape = inputTy.getShape();
1878  auto resultTy = op.getType().cast<ShapedType>();
1879  auto elementTy = inputTy.getElementType();
1880  int64_t rank = inputTy.getRank();
1881 
1882  SmallVector<int64_t> multiples;
1883  getValuesFromIntArrayAttribute(op.getMultiples(), multiples);
1884 
1885  // Broadcast the newly added dimensions to their appropriate multiple.
1886  SmallVector<int64_t, 2> genericShape;
1887  for (int i = 0; i < rank; i++) {
1888  int64_t dim = multiples[i];
1889  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1890  genericShape.push_back(inputShape[i]);
1891  }
1892 
1893  SmallVector<Value> dynDims;
1894  for (int i = 0; i < inputTy.getRank(); i++) {
1895  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1896  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1897  }
1898  }
1899 
1900  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1901  op.getLoc(), genericShape, elementTy, dynDims);
1902 
1903  // We need to map the input shape to the non-broadcasted dimensions.
1904  SmallVector<AffineExpr, 4> dimExprs;
1905  dimExprs.reserve(rank);
1906  for (unsigned i = 0; i < rank; ++i)
1907  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1908 
1909  auto readAffineMap =
1910  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1911  rewriter.getContext());
1912 
1913  SmallVector<AffineMap, 2> affineMaps = {
1914  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1915 
1916  auto genericOp = rewriter.create<linalg::GenericOp>(
1917  loc, RankedTensorType::get(genericShape, elementTy), input,
1918  ValueRange{emptyTensor}, affineMaps,
1919  getNParallelLoopsAttrs(genericShape.size()),
1920  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1921  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1922  });
1923 
1924  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1925  op, resultTy, genericOp.getResult(0),
1926  rewriter.getI64ArrayAttr(resultTy.getShape()));
1927  return success();
1928  }
1929 };
1930 
1931 class PadConverter : public OpRewritePattern<tosa::PadOp> {
1932 public:
1933  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
1934 
1935  LogicalResult matchAndRewrite(tosa::PadOp padOp,
1936  PatternRewriter &rewriter) const final {
1937  auto loc = padOp.getLoc();
1938  auto input = padOp.getInput1();
1939  auto padding = padOp.getPadding();
1940 
1941  ShapedType inputTy = input.getType().cast<ShapedType>();
1942  Type elementTy = inputTy.getElementType();
1943  int64_t rank = inputTy.getRank();
1944 
1945  // Set up the default constantAttr.
1946 
1947  Value padConstant;
1948 
1949  if (padOp.getPadConst()) {
1950  padConstant = rewriter.createOrFold<tensor::ExtractOp>(
1951  loc, padOp.getPadConst(), ValueRange({}));
1952  } else {
1953  Attribute constantAttr;
1954  if (elementTy.isa<FloatType>()) {
1955  constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
1956  } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
1957  constantAttr = rewriter.getIntegerAttr(elementTy, 0);
1958  } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
1959  int64_t value = padOp.getQuantizationInfo()->getInputZp();
1960  constantAttr = rewriter.getIntegerAttr(elementTy, value);
1961  }
1962  if (constantAttr)
1963  padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
1964  }
1965 
1966  if (!padConstant) {
1967  return rewriter.notifyMatchFailure(
1968  padOp, "tosa.pad was unable to determine the pad constant value.");
1969  }
1970 
1971  Value lowIndex =
1972  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
1973  Value highIndex =
1974  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
1975 
1976  SmallVector<OpFoldResult, 3> lowValues;
1977  SmallVector<OpFoldResult, 3> highValues;
1978 
1979  lowValues.reserve(rank);
1980  highValues.reserve(rank);
1981 
1982  for (int i = 0; i < rank; i++) {
1983  Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
1984  Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
1985  loc, padding, ValueRange({inputIndex, lowIndex}));
1986  Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
1987  loc, padding, ValueRange({inputIndex, highIndex}));
1988 
1989  lowVal = rewriter.createOrFold<arith::IndexCastOp>(
1990  loc, rewriter.getIndexType(), lowVal);
1991  highVal = rewriter.createOrFold<arith::IndexCastOp>(
1992  loc, rewriter.getIndexType(), highVal);
1993 
1994  lowValues.push_back(lowVal);
1995  highValues.push_back(highVal);
1996  }
1997 
1998  auto newPadOp = rewriter.create<tensor::PadOp>(
1999  loc, padOp.getType(), input, lowValues, highValues, padConstant);
2000 
2001  rewriter.replaceOp(padOp, newPadOp.getResult());
2002  return success();
2003  }
2004 };
2005 
2006 // Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
2007 // op, producing two output buffers.
2008 //
2009 // The first output buffer contains the index of the found maximum value. It is
2010 // initialized to 0 and is of the resulting integer type.
2011 //
2012 // The second output buffer contains the maximum value found. It is initialized
2013 // to the minimum representable value of the input element type. After being
2014 // populated by indexed_generic, this buffer is discarded as only the index is
2015 // requested.
2016 //
2017 // The indexed_generic op updates both the maximum value and index if the
2018 // current value exceeds the running max.
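// For example, arg_max along axis 0 of [3, 7, 5] starts from (index, max) =
// (0, MIN), updates to (1, 7) at the second element, leaves (1, 7) at the
// third, and yields index 1.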
2019 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2020 public:
2021  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2022 
2023  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2024  PatternRewriter &rewriter) const final {
2025  auto loc = argmaxOp.getLoc();
2026  Value input = argmaxOp.getInput();
2027  auto inputTy = input.getType().cast<ShapedType>();
2028  auto resultTy = argmaxOp.getOutput().getType().cast<ShapedType>();
2029  auto inElementTy = inputTy.getElementType();
2030  auto outElementTy = resultTy.getElementType();
2031  int axis = argmaxOp.getAxis();
2032  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2033 
2034  if (!outElementTy.isa<IntegerType>())
2035  return rewriter.notifyMatchFailure(
2036  argmaxOp,
2037  "tosa.arg_max to linalg.* requires integer-like result type");
2038 
2039  SmallVector<Value> dynDims;
2040  for (int i = 0; i < inputTy.getRank(); i++) {
2041  if (inputTy.isDynamicDim(i) && i != axis) {
2042  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2043  }
2044  }
2045 
2046  // First fill the output buffer for the index.
2047  auto emptyTensorIdx = rewriter
2048  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2049  outElementTy, dynDims)
2050  .getResult();
2051  auto fillValueIdx = rewriter.create<arith::ConstantOp>(
2052  loc, rewriter.getIntegerAttr(outElementTy, 0));
2053  auto filledTensorIdx =
2054  rewriter
2055  .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
2056  ValueRange{emptyTensorIdx})
2057  .result();
2058 
2059  // Second fill the output buffer for the running max.
2060  auto emptyTensorMax = rewriter
2061  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2062  inElementTy, dynDims)
2063  .getResult();
2064  auto fillValueMaxAttr =
2065  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2066 
2067  if (!fillValueMaxAttr)
2068  return rewriter.notifyMatchFailure(
2069  argmaxOp, "unsupported tosa.argmax element type");
2070 
2071  auto fillValueMax =
2072  rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
2073  auto filledTensorMax =
2074  rewriter
2075  .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
2076  ValueRange{emptyTensorMax})
2077  .result();
2078 
2079  // We need to reduce along the arg-max axis, with parallel operations along
2080  // the rest.
2081  SmallVector<utils::IteratorType> iteratorTypes;
2082  iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2083  iteratorTypes[axis] = utils::IteratorType::reduction;
2084 
2085  SmallVector<AffineExpr, 2> srcExprs;
2086  SmallVector<AffineExpr, 2> dstExprs;
2087  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2088  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2089  if (axis != i)
2090  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2091  }
2092 
2093  bool didEncounterError = false;
2094  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
2095  auto linalgOp = rewriter.create<linalg::GenericOp>(
2096  loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2097  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2098  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2099  ValueRange blockArgs) {
2100  auto newValue = blockArgs[0];
2101  auto oldIndex = blockArgs[1];
2102  auto oldValue = blockArgs[2];
2103 
2104  Value newIndex = rewriter.create<arith::IndexCastOp>(
2105  nestedLoc, oldIndex.getType(),
2106  rewriter.create<linalg::IndexOp>(loc, axis));
2107 
2108  Value predicate;
2109  if (inElementTy.isa<FloatType>()) {
2110  predicate = rewriter.create<arith::CmpFOp>(
2111  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2112  } else if (inElementTy.isa<IntegerType>()) {
2113  predicate = rewriter.create<arith::CmpIOp>(
2114  nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2115  } else {
2116  didEncounterError = true;
2117  return;
2118  }
2119 
2120  auto resultMax = rewriter.create<arith::SelectOp>(
2121  nestedLoc, predicate, newValue, oldValue);
2122  auto resultIndex = rewriter.create<arith::SelectOp>(
2123  nestedLoc, predicate, newIndex, oldIndex);
2124  nestedBuilder.create<linalg::YieldOp>(
2125  nestedLoc, ValueRange({resultIndex, resultMax}));
2126  });
2127 
2128  if (didEncounterError)
2129  return rewriter.notifyMatchFailure(
2130  argmaxOp, "unsupported tosa.argmax element type");
2131 
2132  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2133  return success();
2134  }
2135 };
2136 
2137 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2138 public:
2139  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2140  LogicalResult
2141  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2142  ConversionPatternRewriter &rewriter) const final {
2143  auto input = adaptor.getOperands()[0];
2144  auto indices = adaptor.getOperands()[1];
2145 
2146  auto resultTy = op.getType().cast<ShapedType>();
2147 
2148  auto dynamicDimsOr = checkHasDynamicBatchDims(
2149  rewriter, op, {input, indices, op.getOutput()});
2150  if (!dynamicDimsOr.has_value())
2151  return failure();
2152  SmallVector<Value> dynamicDims = dynamicDimsOr.value();
2153 
2154  auto resultElementTy = resultTy.getElementType();
2155 
2156  auto loc = op.getLoc();
2157 
2158  auto emptyTensor =
2159  rewriter
2160  .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2161  dynamicDims)
2162  .getResult();
2163 
2164  SmallVector<AffineMap, 2> affineMaps = {
2165  AffineMap::get(
2166  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2167  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2168  rewriter.getContext()),
2169  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2170 
2171  auto genericOp = rewriter.create<linalg::GenericOp>(
2172  loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2173  ValueRange{emptyTensor}, affineMaps,
2174  getNParallelLoopsAttrs(resultTy.getRank()),
2175  [&](OpBuilder &b, Location loc, ValueRange args) {
2176  auto indexValue = args[0];
2177  auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2178  Value index1 = rewriter.create<arith::IndexCastOp>(
2179  loc, rewriter.getIndexType(), indexValue);
2180  auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2181  Value extract = rewriter.create<tensor::ExtractOp>(
2182  loc, input, ValueRange{index0, index1, index2});
2183  rewriter.create<linalg::YieldOp>(loc, extract);
2184  });
2185  rewriter.replaceOp(op, genericOp.getResult(0));
2186  return success();
2187  }
2188 };
2189 
2190 // Lowers the TableOp to a series of gathers and numeric operations. This
2191 // includes interpolation between the high/low values. For the I8 variant, this
2192 // simplifies to a single gather operation.
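// In the i8 case below, the table is addressed directly: index = input + 128
// maps the input range [-128, 127] onto table entries [0, 255].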
2193 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2194 public:
2195  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2196 
2197  LogicalResult matchAndRewrite(tosa::TableOp op,
2198  PatternRewriter &rewriter) const final {
2199  auto loc = op.getLoc();
2200  Value input = op.getInput();
2201  Value table = op.getTable();
2202  auto inputTy = input.getType().cast<ShapedType>();
2203  auto tableTy = table.getType().cast<ShapedType>();
2204  auto resultTy = op.getType().cast<ShapedType>();
2205 
2206  auto inputElementTy = inputTy.getElementType();
2207  auto tableElementTy = tableTy.getElementType();
2208  auto resultElementTy = resultTy.getElementType();
2209 
2210  SmallVector<Value> dynDims;
2211  for (int i = 0; i < resultTy.getRank(); ++i) {
2212  if (inputTy.isDynamicDim(i)) {
2213  dynDims.push_back(
2214  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2215  }
2216  }
2217 
2218  auto emptyTensor = rewriter
2219  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2220  resultElementTy, dynDims)
2221  .getResult();
2222 
2223  SmallVector<AffineMap, 2> affineMaps = {
2224  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2225  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2226 
2227  auto genericOp = rewriter.create<linalg::GenericOp>(
2228  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2229  getNParallelLoopsAttrs(resultTy.getRank()));
2230  rewriter.replaceOp(op, genericOp.getResult(0));
2231 
2232  {
2233  OpBuilder::InsertionGuard regionGuard(rewriter);
2234  Block *block = rewriter.createBlock(
2235  &genericOp.getRegion(), genericOp.getRegion().end(),
2236  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2237 
2238  auto inputValue = block->getArgument(0);
2239  rewriter.setInsertionPointToStart(block);
2240  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2241  resultElementTy.isInteger(8)) {
2242  Value index = rewriter.create<arith::IndexCastOp>(
2243  loc, rewriter.getIndexType(), inputValue);
2244  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2245  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2246  index, offset);
2247  Value extract =
2248  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2249  rewriter.create<linalg::YieldOp>(loc, extract);
2250  return success();
2251  }
2252 
2253  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2254  resultElementTy.isInteger(32)) {
2255  Value extend = rewriter.create<arith::ExtSIOp>(
2256  loc, rewriter.getI32Type(), inputValue);
2257 
2258  auto offset = rewriter.create<arith::ConstantOp>(
2259  loc, rewriter.getI32IntegerAttr(32768));
2260  auto seven = rewriter.create<arith::ConstantOp>(
2261  loc, rewriter.getI32IntegerAttr(7));
2262  auto one = rewriter.create<arith::ConstantOp>(
2263  loc, rewriter.getI32IntegerAttr(1));
2264  auto b1111111 = rewriter.create<arith::ConstantOp>(
2265  loc, rewriter.getI32IntegerAttr(127));
2266 
2267  // Compute the index and fractional part from the input value:
2268  // value = value + 32768
2269  // index = value >> 7;
2270  // fraction = 0b1111111 & value
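  // Illustrative values: input -32768 rebases to 0 (index = 0,
  // fraction = 0); input 0 rebases to 32768 (index = 256, fraction = 0);
  // input -32700 rebases to 68 (index = 0, fraction = 68).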
2271  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2272  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2273  Value fraction =
2274  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2275 
2276  // Extract the base and next values from the table.
2277  // base = (int32_t) table[index];
2278  // next = (int32_t) table[index + 1];
2279  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2280 
2281  index = rewriter.create<arith::IndexCastOp>(
2282  loc, rewriter.getIndexType(), index);
2283  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2284  loc, rewriter.getIndexType(), indexPlusOne);
2285 
2286  Value base =
2287  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2288  Value next = rewriter.create<tensor::ExtractOp>(
2289  loc, table, ValueRange{indexPlusOne});
2290 
2291  base =
2292  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2293  next =
2294  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2295 
2296  // Use the fractional part to interpolate between the input values:
2297  // result = (base << 7) + (next - base) * fraction
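  // E.g. base = 100, next = 228, fraction = 64 (the halfway point) gives
  // (100 << 7) + 128 * 64 = 12800 + 8192 = 20992 = 164 << 7, the midpoint
  // of the two entries scaled by 128.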
2298  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2299  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2300  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2301  Value result =
2302  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2303 
2304  rewriter.create<linalg::YieldOp>(loc, result);
2305 
2306  return success();
2307  }
2308  }
2309 
2310  return rewriter.notifyMatchFailure(
2311  op, "unable to create body for tosa.table op");
2312  }
2313 };
2314 
2315 } // namespace
2316 
2317 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2318  RewritePatternSet *patterns) {
2319 
2320  // We have multiple resize converters to handle degenerate cases.
2321  patterns->add<GenericResizeConverter>(patterns->getContext(),
2322  /*benefit=*/100);
2323  patterns->add<BroadcastResizeConverter>(patterns->getContext(),
2324  /*benefit=*/200);
2325 
2326  patterns->add<
2327  // clang-format off
2328  PointwiseConverter<tosa::AddOp>,
2329  PointwiseConverter<tosa::SubOp>,
2330  PointwiseConverter<tosa::MulOp>,
2331  PointwiseConverter<tosa::DivOp>,
2332  PointwiseConverter<tosa::NegateOp>,
2333  PointwiseConverter<tosa::PowOp>,
2334  PointwiseConverter<tosa::ReciprocalOp>,
2335  PointwiseConverter<tosa::RsqrtOp>,
2336  PointwiseConverter<tosa::LogOp>,
2337  PointwiseConverter<tosa::ExpOp>,
2338  PointwiseConverter<tosa::AbsOp>,
2339  PointwiseConverter<tosa::TanhOp>,
2340  PointwiseConverter<tosa::BitwiseAndOp>,
2341  PointwiseConverter<tosa::BitwiseOrOp>,
2342  PointwiseConverter<tosa::BitwiseNotOp>,
2343  PointwiseConverter<tosa::BitwiseXorOp>,
2344  PointwiseConverter<tosa::LogicalAndOp>,
2345  PointwiseConverter<tosa::LogicalNotOp>,
2346  PointwiseConverter<tosa::LogicalOrOp>,
2347  PointwiseConverter<tosa::LogicalXorOp>,
2348  PointwiseConverter<tosa::CastOp>,
2349  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2350  PointwiseConverter<tosa::LogicalRightShiftOp>,
2351  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2352  PointwiseConverter<tosa::ClzOp>,
2353  PointwiseConverter<tosa::SelectOp>,
2354  PointwiseConverter<tosa::GreaterOp>,
2355  PointwiseConverter<tosa::GreaterEqualOp>,
2356  PointwiseConverter<tosa::EqualOp>,
2357  PointwiseConverter<tosa::MaximumOp>,
2358  PointwiseConverter<tosa::MinimumOp>,
2359  PointwiseConverter<tosa::CeilOp>,
2360  PointwiseConverter<tosa::FloorOp>,
2361  PointwiseConverter<tosa::ClampOp>,
2362  PointwiseConverter<tosa::SigmoidOp>,
2363  IdentityNConverter<tosa::IdentityOp>,
2364  ReduceConverter<tosa::ReduceAllOp>,
2365  ReduceConverter<tosa::ReduceAnyOp>,
2366  ReduceConverter<tosa::ReduceMinOp>,
2367  ReduceConverter<tosa::ReduceMaxOp>,
2368  ReduceConverter<tosa::ReduceSumOp>,
2369  ReduceConverter<tosa::ReduceProdOp>,
2370  ArgMaxConverter,
2371  ConcatConverter,
2372  GatherConverter,
2373  PadConverter,
2374  ReshapeConverterCollapse,
2375  ReshapeConverterExpand,
2376  ReshapeConverterCollapseExpand,
2377  RescaleConverter,
2378  ReverseConverter,
2379  TableConverter,
2380  TileConverter,
2381  TransposeConverter>(patterns->getContext());
2382  // clang-format on
2383 }