MLIR  15.0.0git
TosaToLinalg.cpp
Go to the documentation of this file.
1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
27 
28 #include <numeric>
29 
30 using namespace mlir;
31 using namespace mlir::tosa;
32 
33 template <typename T>
34 static arith::ConstantOp
35 createConstFromIntAttribute(Operation *op, const std::string &attrName,
36  Type requiredAttrType, OpBuilder &rewriter) {
37  auto castedN = static_cast<T>(
38  op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
39  return rewriter.create<arith::ConstantOp>(
40  op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
41 }
42 
// Emits the scalar computation implementing one TOSA elementwise op inside
// the body of a linalg.generic. `args` holds the scalar block arguments (one
// per operand), `resultTypes` the scalar result types. Returns the computed
// Value, or nullptr (after notifyMatchFailure) when the op / element-type
// combination has no lowering here.
// NOTE(review): the first line of the signature (the `op` and `args`
// parameters) was dropped by the doc extraction; only the trailing
// parameters are visible below.
static Value
                                   ArrayRef<Type> resultTypes,
                                   PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  // The first operand's element type selects the int-vs-float lowering.
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::AbsOp>(loc, resultTypes, args);

  // Integer abs has no direct arith op: abs(x) = select(x > 0, x, 0 - x).
  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                              args[0], zero);
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    // A non-zero `shift` is only meaningful for integer multiplies.
    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      // With a non-zero shift the product is computed at i32 and scaled back
      // down by tosa.apply_scale; narrower operands are sign-extended first.
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      // Truncate the i32 scaled product back to the requested element type.
      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    // No shift: plain integer multiply; sign-extend narrower operands up to
    // the result width first.
    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  // Unquantized integer negate: 0 - x.
  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
  }

  // Quantized integer negate: compute in a wider type so the zero-point
  // arithmetic cannot overflow, then clamp and truncate back.
  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp = quantizationInfo.getValue().getInputZp();
    int64_t outZp = quantizationInfo.getValue().getOutputZp();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //  outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp = clampHelper<arith::CmpIOp>(
        loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  // Bitwise-not is lowered as x XOR all-ones.
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    // Without rounding the plain arithmetic shift is the answer.
    if (!round) {
      return result;
    }

    // Rounding: add 1 to the shifted result when the shift amount is > 0 and
    // the last bit shifted out was 1.
    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking for the last bit of input1 to be 1
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  // Boolean not is x XOR 1.
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  // The element type of interest is the selected operand's (operand 1), not
  // the i1 condition's.
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OGT, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                  op->getAttr("min_fp"));
    auto max = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                  op->getAttr("max_fp"));
    return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
                                      arith::CmpFPredicate::OLT, rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto intTy = elementTy.cast<IntegerType>();
    int32_t min = static_cast<int32_t>(
        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
    int32_t max = static_cast<int32_t>(
        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());

    // Tighten the requested clamp bounds to the representable range of the
    // element type (unsigned vs. signed).
    if (intTy.isUnsignedInteger()) {
      min = std::max<int32_t>(min, 0);
      max = std::min<int32_t>(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min = std::max<int32_t>(
          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
      max = std::min<int32_t>(
          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
    }

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
                                      arith::CmpIPredicate::slt, rewriter);
  }

  // tosa::ReluNOp
  // ReluN is a clamp to [0, max].
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                op->getAttr("max_fp"));
    return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
                                      arith::CmpFPredicate::OLT, rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
                                      arith::CmpIPredicate::slt, rewriter);
  }

  // tosa::SigmoidOp
  // sigmoid(x) = 1 / (1 + exp(-x)).
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    // NOTE(review): the expression initializing `bitExtend` (presumably a
    // source-width vs. destination-width comparison) was dropped by the doc
    // extraction here.
    bool bitExtend =

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             mlir::None);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              mlir::None);

    // Casting to boolean, floats need to only be checked as not-equal to zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    // fp -> int: round half away from zero (add/sub 0.5 depending on sign),
    // clamp to the destination's signed range, then convert.
    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(0.0f));
      auto half = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(0.5f));

      auto intMin = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto added = rewriter.create<arith::AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<arith::SubFOp>(loc, args[0], half);
      auto negative = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<arith::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampHelper<arith::CmpFOp>(
          loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);

      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             mlir::None);

    // Narrowing int -> int: clamp to the destination's signed range before
    // truncating.
    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      auto intMin = rewriter.create<arith::ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<arith::ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampHelper<arith::CmpIOp>(
          loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
      return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
    }
  }

  // No lowering matched: report and let the caller fail the pattern.
  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
555 
// Rewrites a single-result TOSA elementwise op into a linalg.generic whose
// body is produced by createLinalgBodyCalculationForElementwiseOp. Handles
// broadcasting of lower-information operands via tosa.reshape plus reduced
// indexing maps, and dynamic result dims via tensor.dim.
// NOTE(review): the signature's first line (function name and `operation`
// parameter) was dropped by the doc extraction.
static LogicalResult
                                PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;

  // Gather one tensor.dim per dynamic result dimension, taking the first
  // operand that can supply it.
  SmallVector<Value> dynDims;
  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());

  for (auto arg : operation->getOperands()) {
    auto operandTy = arg.getType().cast<ShapedType>();
    for (int i = 0; i < operandTy.getRank(); i++) {
      if (operandTy.isDynamicDim(i) && !dynDims[i])
        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
    }
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);

  // One init tensor per result, matching its shape and element type.
  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    // Exact shape match: identity map, no reshape needed.
    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    // Keep only the dims that match the result; broadcast dims are dropped
    // from both the operand's shape and its indexing map.
    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (const auto &it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
          rewriter.getI64ArrayAttr(newShape));
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  // Results always use identity maps.
  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  // NOTE(review): two lines were dropped by the doc extraction inside this
  // create<> call: the iterator-types argument and the
  // `Value opResult = createLinalgBodyCalculationForElementwiseOp(` line that
  // begins the body-builder's first statement.
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}
666 
// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
// Returns a null Attribute for unsupported op / element-type combinations so
// the caller can report a match failure.
// NOTE(review): the signature line (function name, `op`, and `elementTy`
// parameters) was dropped by the doc extraction.
                                        PatternRewriter &rewriter) {
  // Sum: identity is 0.
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 0);

  // Product: identity is 1.
  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 1);

  // Min: start from the largest representable value.
  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  // Max: start from the most negative representable value.
  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  // All: identity is true (all ones); Any: identity is false (zero).
  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  // ArgMax seeds its comparison value like a max-reduction.
  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  // Unsupported combination.
  return {};
}
718 
// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
// `args[0]` is the incoming element, `args[1]` the running accumulator; the
// returned Value is the new accumulator, or null if unsupported.
// NOTE(review): the signature's first line (function name and `op` parameter)
// was dropped by the doc extraction.
                                                  ValueRange args,
                                                  Type elementTy,
                                                  PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  // Min/max are expressed as compare + select.
  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OGT, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // Boolean all/any reduce with and/or.
  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  // Unsupported combination: caller treats a null Value as failure.
  return {};
}
774 
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
// NOTE(review): the signature's first line (function name, `op`, and the
// reduction-axis parameter) was dropped by the doc extraction.
                                              PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // Build the reduced shape (input shape with the reduced axis removed) and
  // collect tensor.dim values for any dynamic kept dims.
  llvm::SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto initTensor = rewriter
                        .create<linalg::InitTensorOp>(loc, dynDims, reduceShape,
                                                      resultTy.getElementType())
                        .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{initTensor})
                          .result();

  // NOTE(review): the declarations of the `srcExprs`/`dstExprs` vectors used
  // below were dropped by the doc extraction here.
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    // The reduced axis iterates as a reduction; the dst map drops that dim.
    // NOTE(review): the `:` continuation of this ternary (the parallel
    // iterator name) was dropped by the doc extraction.
    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  // NOTE(review): the `auto result = createLinalgBodyCalculationForReduceOp(`
  // line opening the body builder's first statement was dropped by the doc
  // extraction.
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
            op, blockArgs, elementTy, rewriter);
        // Despite its name, this flag records that a body value WAS produced;
        // the `!didEncounterError` check below fails the pattern when no
        // reduction body could be built.
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return failure();

  // Reshape the reduced tensor into the exact declared result type.
  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}
847 
    // Attempts to find a shape that both lhsShape and rhsShape can be
    // reshaped to, by greedily grouping runs of dimensions whose products
    // match. Writes the result into `intermediateShape` and returns true on
    // success. In the dynamic case the intermediate shape is always rank-1
    // unknown ({-1}).
849  ArrayRef<int64_t> rhsShape,
850  SmallVector<int64_t> &intermediateShape,
851  bool isDynamic) {
852  if (isDynamic) {
853  // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
854  intermediateShape = {-1};
855  return true;
856  }
857 
    // A rank-0 operand on either side collapses to the empty shape.
858  if (lhsShape.empty() || rhsShape.empty()) {
859  intermediateShape = {};
860  return true;
861  }
862 
863  unsigned currLhsDim = 0, currRhsDim = 0;
864  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
865  int64_t rhsSize = rhsShape[currRhsDim];
866  int64_t lhsSize = lhsShape[currLhsDim];
    // Grow whichever side has the smaller running product until the two
    // products agree (or a side runs out of dimensions).
867  while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
868  currRhsDim < rhsShape.size()) {
869  if (lhsSize < rhsSize) {
870  currLhsDim++;
871  lhsSize *= lhsShape[currLhsDim];
872  } else {
873  currRhsDim++;
874  rhsSize *= rhsShape[currRhsDim];
875  }
876  }
877  if (lhsSize == rhsSize) {
878  intermediateShape.push_back(lhsSize);
879  }
880  currRhsDim++;
881  currLhsDim++;
882  }
883 
884  // If the iterators didn't reach the end and their leftover dimensions are not
885  // equal to 1 an intermediate shape was not found.
886  while (currLhsDim < lhsShape.size()) {
887  if (lhsShape[currLhsDim++] != 1) {
888  return false;
889  }
890  }
891 
892  while (currRhsDim < rhsShape.size()) {
893  if (rhsShape[currRhsDim++] != 1) {
894  return false;
895  }
896  }
897 
898  return true;
899 }
900 
    // Builds the reassociation map that groups contiguous srcShape dimensions
    // into dstShape dimensions (for tensor.collapse_shape / expand_shape).
    // Returns false if the two shapes cannot be related this way.
902  PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
903  ArrayRef<int64_t> dstShape,
904  SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
905 
906  // If the shape is dynamic, create a map for collapsing into one dimension.
907  if (isDynamic) {
909  for (int i = 0, s = srcShape.size(); i < s; ++i)
910  exprs.push_back(rewriter.getAffineDimExpr(i));
911  reassociationMap = {exprs};
912  return true;
913  }
914 
    // A rank-0 destination needs no reassociation groups at all.
915  if (dstShape.empty()) {
916  reassociationMap = {};
917  return true;
918  }
919 
920  reassociationMap.resize(dstShape.size());
921  unsigned currSrcDim = 0, currDstDim = 0;
922  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
923  int64_t dstSize = dstShape[currDstDim];
924  int64_t srcSize = srcShape[currSrcDim];
    // Accumulate source dims into the current destination group until their
    // product reaches the destination extent.
925  while (srcSize < dstSize && currSrcDim < srcShape.size()) {
926  reassociationMap[currDstDim].push_back(
927  rewriter.getAffineDimExpr(currSrcDim++));
928  srcSize *= srcShape[currSrcDim];
929  }
930  if (srcSize == dstSize) {
931  reassociationMap[currDstDim].push_back(
932  rewriter.getAffineDimExpr(currSrcDim++));
933  // If the next dim in collapsedShape is not 1, treat subsequent dims in
934  // expandedShape which are 1 to be collapsed.
935  if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
936  while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
937  reassociationMap[currDstDim].push_back(
938  rewriter.getAffineDimExpr(currSrcDim++));
939  }
940  }
941  }
942  currDstDim++;
943  }
944 
945  // If both iterators didn't reach the end, we have leftover dimentions which
946  // implies that we have a mismatch in shape.
947  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
948 }
949 
950 namespace {
951 
// Lowers an elementwise TOSA op by delegating entirely to
// elementwiseMatchAndRewriteHelper.
952 template <typename SrcOp>
953 class PointwiseConverter : public OpRewritePattern<SrcOp> {
954 public:
 956 
957  LogicalResult matchAndRewrite(SrcOp op,
958  PatternRewriter &rewriter) const final {
959  return elementwiseMatchAndRewriteHelper(op, rewriter);
960  }
961 };
962 
// Lowers tosa.reshape to tensor.collapse_shape when the result rank is lower
// than (or equal to) the operand rank.
963 class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
964 public:
 966 
 968  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
969  ConversionPatternRewriter &rewriter) const final {
970  ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
971  ShapedType resultTy = reshape.getType().template cast<ShapedType>();
972  bool isDynamic = !operandTy.hasStaticShape();
973 
    // Dynamic operands are only supported when collapsing everything into a
    // single dimension (see the rank-1 restriction in findIntermediateShape).
974  if (isDynamic && resultTy.getRank() != 1) {
975  return rewriter.notifyMatchFailure(
976  reshape, "Cannot collapse dynamic dims to more than one dimension");
977  }
978 
    // Identity reshape: forward the operand directly.
979  if (operandTy == resultTy) {
980  rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
981  return success();
982  }
983 
984  SmallVector<ReassociationExprs, 4> reassociationMap;
985  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
986  resultTy.getShape(),
987  reassociationMap, isDynamic)) {
988  return rewriter.notifyMatchFailure(
989  reshape,
990  "tosa.reshape Attempting to collapse into an incompatible shape");
991  }
992 
    // The intermediate shape is computed only to validate compatibility; its
    // value is unused here.
993  SmallVector<int64_t> intermediateShape;
994  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
995  intermediateShape, isDynamic)) {
996  return rewriter.notifyMatchFailure(
997  reshape, "tosa.reshape Cannot collapse into given shape");
998  }
999 
1000  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1001  reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
1002  return success();
1003  }
1004 };
1005 
// Lowers tosa.reshape to tensor.expand_shape when the result rank is higher
// than the operand rank.
1006 class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
1007 public:
 1009 
 1011  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
1012  ConversionPatternRewriter &rewriter) const final {
1013  ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
1014  ShapedType resultTy = reshape.getType().template cast<ShapedType>();
1015  bool isDynamic = !operandTy.hasStaticShape();
1016 
    // Identity reshape: forward the operand directly.
1017  if (operandTy == resultTy) {
1018  rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
1019  return success();
1020  }
1021 
    // Dynamic expansion is only supported from a rank-1 operand (mirror of
    // the collapse pattern's restriction).
1022  if (isDynamic && operandTy.getRank() != 1) {
1023  return rewriter.notifyMatchFailure(
1024  reshape, "Cannot expand dynamic dims from more than one dimension");
1025  }
1026 
    // Note the swapped src/dst arguments: expansion reuses the collapse
    // reassociation logic by treating the result as the "expanded" side.
1027  SmallVector<ReassociationExprs, 4> reassociationMap;
1028  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
1029  operandTy.getShape(),
1030  reassociationMap, isDynamic)) {
1031  return rewriter.notifyMatchFailure(
1032  reshape,
1033  "tosa.reshape Attempting to expand into an incompatible shape");
1034  }
1035 
    // For an expand, the operand itself must already be the intermediate
    // (fully collapsed) shape.
1036  SmallVector<int64_t> intermediateShape;
1037  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
1038  intermediateShape, isDynamic) ||
1039  intermediateShape != operandTy.getShape()) {
1040  return rewriter.notifyMatchFailure(
1041  reshape, "tosa.reshape Cannot expand into given shape");
1042  }
1043  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1044  reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
1045  return success();
1046  }
1047 };
1048 
// Handles reshapes that are neither a pure collapse nor a pure expand by
// splitting them into a collapse to an intermediate shape followed by an
// expand; the two resulting tosa.reshape ops are then lowered by the
// patterns above.
1049 class ReshapeConverterCollapseExpand
1050  : public OpConversionPattern<tosa::ReshapeOp> {
1051 public:
 1053 
 1055  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
1056  ConversionPatternRewriter &rewriter) const final {
1057  ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
1058  ShapedType resultTy = reshape.getType().template cast<ShapedType>();
1059  bool isDynamic = !operandTy.hasStaticShape();
1060 
    // Identity reshape: forward the operand directly.
1061  if (operandTy == resultTy) {
1062  rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
1063  return success();
1064  }
1065 
1066  SmallVector<int64_t> intermediateShape;
1067  if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
1068  intermediateShape, isDynamic)) {
1069  return rewriter.notifyMatchFailure(
1070  reshape, "tosa.reshape Cannot identify an intermediate shape between "
1071  "the given two shapes");
1072  }
1073 
    // First collapse to the intermediate shape, then expand to the result.
1074  Value collapse = rewriter.create<tosa::ReshapeOp>(
1075  reshape.getLoc(),
1076  RankedTensorType::get(intermediateShape,
1077  reshape.getType().getElementType()),
1078  adaptor.input1());
1079  Value expand =
1080  rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
1081  rewriter.replaceOp(reshape, expand);
1082 
1083  return success();
1084  }
1085 };
1086 
// Lowers tosa.transpose to a linalg.generic whose input indexing map encodes
// the permutation. Requires the permutation operand to be a constant.
1087 class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1088 public:
 1090 
1091  LogicalResult matchAndRewrite(tosa::TransposeOp op,
1092  PatternRewriter &rewriter) const final {
1093  DenseIntElementsAttr perms;
1094  if (!matchPattern(op.perms(), m_Constant(&perms))) {
1095  return failure();
1096  }
1097 
1098  auto loc = op.getLoc();
1099  auto input = op->getOperand(0);
1100  auto resultTy = op.getType().cast<ShapedType>();
1101 
1102  SmallVector<Value> dynDims;
1103  dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
1104 
    // inputExprs[perm[i]] = dim(i): result dimension perm[i] reads input
    // dimension i. Dynamic output extents are taken from the corresponding
    // input dimension.
1105  SmallVector<AffineExpr, 2> inputExprs;
1106  inputExprs.resize(resultTy.getRank());
1107  auto operandTy = input.getType().cast<ShapedType>();
1108  for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
1109  auto index = permutation.index();
1110  auto value = permutation.value().getZExtValue();
1111  if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
1112  dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
1113  }
1114  inputExprs[value] = rewriter.getAffineDimExpr(index);
1115  }
1116 
    // Drop the null entries (static dims) before handing the dynamic sizes to
    // the init tensor.
1117  SmallVector<Value> filteredDims = condenseValues(dynDims);
1118 
1119  auto initTensor = rewriter.create<linalg::InitTensorOp>(
1120  loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
1121 
1122  SmallVector<AffineMap, 2> affineMaps = {
1123  AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
1124  rewriter.getContext()),
1125  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1126 
    // The body is a pure copy; the permutation is carried entirely by the
    // indexing maps.
1127  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1128  op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
1129  getNParallelLoopsAttrs(resultTy.getRank()),
1130  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1131  nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
1132  });
1133  return success();
1134  }
1135 };
1136 
// Lowers tosa.rescale (quantized value rescaling) to a linalg.generic whose
// body performs: widen input, subtract input zero-point, apply scale
// (multiplier/shift via tosa.apply_scale), add output zero-point, clamp to
// the output type's range, then truncate to the output width.
1137 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1138 public:
 1140 
1141  LogicalResult matchAndRewrite(tosa::RescaleOp op,
1142  PatternRewriter &rewriter) const final {
1143  auto loc = op.getLoc();
1144  auto input = op.input();
1145  auto inputTy = op.input().getType().cast<ShapedType>();
1146  auto outputTy = op.output().getType().cast<ShapedType>();
1147  unsigned rank = inputTy.getRank();
1148 
1149  // This is an illegal configuration. terminate and log an error
1150  if (op.double_round() && !op.scale32())
1151  return rewriter.notifyMatchFailure(
1152  op, "tosa.rescale requires scale32 for double_round to be true");
1153 
    // Only dynamic batch dims are supported; bail out otherwise.
1154  auto dynamicDimsOr =
1155  checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
1156  if (!dynamicDimsOr.hasValue())
1157  return failure();
1158  SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
1159 
1160  // The shift and multiplier values.
1161  SmallVector<int32_t> multiplierValues;
1162  getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);
1163 
1164  SmallVector<int8_t> shiftValues;
1165  getValuesFromIntArrayAttribute(op.shift(), shiftValues);
1166 
1167  // If we shift by more than the bitwidth, this just sets to 0.
1168  for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1169  if (shiftValues[i] > 63) {
1170  shiftValues[i] = 0;
1171  multiplierValues[i] = 0;
1172  }
1173  }
1174 
1175  // Double round only occurs if shift is greater than 31, check that this
1176  // is ever true.
1177  bool doubleRound =
1178  op.double_round() &&
1179  llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1180 
1181  SmallVector<AffineMap> indexingMaps = {
1182  rewriter.getMultiDimIdentityMap(rank)};
1183  SmallVector<Value, 4> genericInputs = {input};
1184 
1185  // If we are rescaling per-channel then we need to store the multiplier
1186  // values in a buffer.
    // In the per-channel case the multiplier tensor is indexed by the
    // innermost (channel) dimension only.
1187  Value multiplierConstant;
1188  int64_t multiplierArg = 0;
1189  if (multiplierValues.size() == 1) {
1190  multiplierConstant = rewriter.create<arith::ConstantOp>(
1191  loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1192  } else {
1193  SmallVector<AffineExpr, 2> multiplierExprs{
1194  rewriter.getAffineDimExpr(rank - 1)};
1195  auto multiplierType =
1196  RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1197  rewriter.getI32Type());
1198  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1199  loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1200 
1201  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1202  /*symbolCount=*/0, multiplierExprs,
1203  rewriter.getContext()));
1204 
1205  multiplierArg = indexingMaps.size() - 1;
1206  }
1207 
1208  // If we are rescaling per-channel then we need to store the shift
1209  // values in a buffer.
1210  Value shiftConstant;
1211  int64_t shiftArg = 0;
1212  if (shiftValues.size() == 1) {
1213  shiftConstant = rewriter.create<arith::ConstantOp>(
1214  loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1215  } else {
1216  SmallVector<AffineExpr, 2> shiftExprs = {
1217  rewriter.getAffineDimExpr(rank - 1)};
1218  auto shiftType =
1219  RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1220  rewriter.getIntegerType(8));
1221  genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1222  loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1223  indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1224  /*symbolCount=*/0, shiftExprs,
1225  rewriter.getContext()));
1226  shiftArg = indexingMaps.size() - 1;
1227  }
1228 
1229  // Indexing maps for output values.
1230  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1231 
1232  // Construct the indexing maps needed for linalg.generic ops.
1233  Value initTensor = rewriter.create<linalg::InitTensorOp>(
1234  loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());
1235 
1236  auto linalgOp = rewriter.create<linalg::GenericOp>(
1237  loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
1238  getNParallelLoopsAttrs(rank),
1239  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1240  ValueRange blockArgs) {
1241  Value value = blockArgs[0];
1242  Type valueTy = value.getType();
1243 
1244  // For now we do all of our math in 64-bit. This is not optimal but
1245  // should be correct for now, consider computing correct bit depth
1246  // later.
1247  int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1248 
1249  auto inputZp = createConstFromIntAttribute<int32_t>(
1250  op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1251  nestedBuilder);
1252  auto outputZp = createConstFromIntAttribute<int32_t>(
1253  op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1254 
    // Per-tensor values come from the hoisted constants; per-channel
    // values come in as extra block arguments.
1255  Value multiplier = multiplierConstant ? multiplierConstant
1256  : blockArgs[multiplierArg];
1257  Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1258 
    // Widen sub-32-bit inputs to i32; unsigned values go through a
    // signless cast so arith's zero-extension applies.
1259  if (valueTy.getIntOrFloatBitWidth() < 32) {
1260  if (valueTy.isUnsignedInteger()) {
1261  value = nestedBuilder
1262  .create<UnrealizedConversionCastOp>(
1263  nestedLoc,
1264  nestedBuilder.getIntegerType(
1265  valueTy.getIntOrFloatBitWidth()),
1266  value)
1267  .getResult(0);
1268  value = nestedBuilder.create<arith::ExtUIOp>(
1269  nestedLoc, nestedBuilder.getI32Type(), value);
1270  } else {
1271  value = nestedBuilder.create<arith::ExtSIOp>(
1272  nestedLoc, nestedBuilder.getI32Type(), value);
1273  }
1274  }
1275 
1276  value =
1277  nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1278 
1279  value = nestedBuilder.create<tosa::ApplyScaleOp>(
1280  loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1281  nestedBuilder.getBoolAttr(doubleRound));
1282 
1283  // Move to the new zero-point.
1284  value =
1285  nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1286 
1287  // Saturate to the output size.
1288  IntegerType outIntType =
1289  blockArgs.back().getType().cast<IntegerType>();
1290  unsigned outBitWidth = outIntType.getWidth();
1291 
1292  int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1293  int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1294 
1295  // Unsigned integers have a difference output value.
1296  if (outIntType.isUnsignedInteger()) {
1297  intMin = 0;
1298  intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1299  }
1300 
1301  auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1302  loc, nestedBuilder.getI32IntegerAttr(intMin));
1303  auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1304  loc, nestedBuilder.getI32IntegerAttr(intMax));
1305 
1306  value = clampHelper<arith::CmpIOp>(
1307  nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
1308  nestedBuilder);
1309 
    // Narrow back down to the declared output width; unsigned outputs
    // are cast back from the signless form.
1310  if (outIntType.getWidth() < 32) {
1311  value = nestedBuilder.create<arith::TruncIOp>(
1312  nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1313  value);
1314 
1315  if (outIntType.isUnsignedInteger()) {
1316  value = nestedBuilder
1317  .create<UnrealizedConversionCastOp>(nestedLoc,
1318  outIntType, value)
1319  .getResult(0);
1320  }
1321  }
1322 
1323  nestedBuilder.create<linalg::YieldOp>(loc, value);
1324  });
1325 
1326  rewriter.replaceOp(op, linalgOp->getResults());
1327  return success();
1328  }
1329 };
1330 
// Lowers tosa.resize (NEAREST_NEIGHBOR / BILINEAR, NHWC layout) to a
// linalg.generic that computes each output pixel from the input via
// tensor.extract. shift == 0 selects floating-point coordinate math;
// otherwise fixed-point math with the given shift is used.
1331 class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1332 public:
 1334 
1335  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1336  PatternRewriter &rewriter) const final {
1337  Location loc = op.getLoc();
1338  auto input = op.input();
1339  auto inputTy = input.getType().cast<ShapedType>();
1340  auto resultTy = op.getType().cast<ShapedType>();
1341  auto resultElementTy = resultTy.getElementType();
1342 
    // NHWC: dim 1 is height, dim 2 is width.
1343  auto imageH = inputTy.getShape()[1];
1344  auto imageW = inputTy.getShape()[2];
1345 
1346  auto dynamicDimsOr =
1347  checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
1348  if (!dynamicDimsOr.hasValue())
1349  return failure();
1350  SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
1351 
1352  if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
1353  return failure();
1354 
1355  auto initTensor = rewriter.create<linalg::InitTensorOp>(
1356  loc, dynamicDims, resultTy.getShape(), resultElementTy);
1357 
1358  SmallVector<AffineMap, 2> affineMaps = {
1359  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1360 
    // The generic op has no inputs; the body is built manually below so the
    // pixel gather can use linalg.index + tensor.extract.
1361  auto genericOp = rewriter.create<linalg::GenericOp>(
1362  loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
1363  getNParallelLoopsAttrs(resultTy.getRank()));
1364  rewriter.replaceOp(op, genericOp.getResult(0));
1365 
1366  OpBuilder::InsertionGuard regionGuard(rewriter);
1367  rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
1368  TypeRange({resultElementTy}), loc);
1369  Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
1370  Value y = rewriter.create<linalg::IndexOp>(loc, 1);
1371  Value x = rewriter.create<linalg::IndexOp>(loc, 2);
1372  Value channel = rewriter.create<linalg::IndexOp>(loc, 3);
1373 
    // Bounds used to clamp computed coordinates into the input image.
1374  auto hwMin =
1375  rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1376  auto hMax = rewriter.create<arith::ConstantOp>(
1377  loc, rewriter.getI32IntegerAttr(imageH - 1));
1378  auto wMax = rewriter.create<arith::ConstantOp>(
1379  loc, rewriter.getI32IntegerAttr(imageW - 1));
1380 
1381  Value inY =
1382  rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
1383  Value inX =
1384  rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);
1385 
1386  int32_t shift = op.shift();
1387  bool floatingPointMode = shift == 0;
1388 
    // Stride/offset constants come from either the *_fp attributes (float
    // mode) or the integer attributes (fixed-point mode).
1389  Value yStride, xStride, yOffset, xOffset;
1390  if (floatingPointMode) {
1391  yStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[0]);
1392  xStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[1]);
1393  yOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[0]);
1394  xOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[1]);
1395  } else {
1396  SmallVector<int32_t> stride, offset;
1397  getValuesFromIntArrayAttribute(op.stride(), stride);
1398  getValuesFromIntArrayAttribute(op.offset(), offset);
1399 
1400  yStride = rewriter.create<arith::ConstantOp>(
1401  loc, rewriter.getI32IntegerAttr(stride[0]));
1402  xStride = rewriter.create<arith::ConstantOp>(
1403  loc, rewriter.getI32IntegerAttr(stride[1]));
1404  yOffset = rewriter.create<arith::ConstantOp>(
1405  loc, rewriter.getI32IntegerAttr(offset[0]));
1406  xOffset = rewriter.create<arith::ConstantOp>(
1407  loc, rewriter.getI32IntegerAttr(offset[1]));
1408  }
1409 
1410  // Compute the the integer index and partial offset.
1411  // x = x * stride + offset;
1412  // ix = floor(x)
1413  // dx = x - ix
1414  Value ix, iy, dx, dy;
1415  if (floatingPointMode) {
1416  Value y =
1417  rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
1418  Value x =
1419  rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);
1420 
1421  y = rewriter.create<arith::MulFOp>(loc, y, yStride);
1422  x = rewriter.create<arith::MulFOp>(loc, x, xStride);
1423 
1424  y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
1425  x = rewriter.create<arith::AddFOp>(loc, x, xOffset);
1426 
1427  iy = rewriter.create<math::FloorOp>(loc, y);
1428  ix = rewriter.create<math::FloorOp>(loc, x);
1429 
1430  dy = rewriter.create<arith::SubFOp>(loc, y, iy);
1431  dx = rewriter.create<arith::SubFOp>(loc, x, ix);
1432 
1433  iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
1434  ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
1435  } else {
    // Fixed-point: floor is an arithmetic shift right; the fractional part
    // is recovered by shifting back and subtracting.
1436  Value shiftVal = rewriter.create<arith::ConstantOp>(
1437  loc, rewriter.getI32IntegerAttr(shift));
1438 
1439  Value y = rewriter.create<arith::MulIOp>(loc, inY, yStride);
1440  Value x = rewriter.create<arith::MulIOp>(loc, inX, xStride);
1441 
1442  y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
1443  x = rewriter.create<arith::AddIOp>(loc, x, xOffset);
1444 
1445  iy = rewriter.create<arith::ShRSIOp>(loc, y, shiftVal);
1446  ix = rewriter.create<arith::ShRSIOp>(loc, x, shiftVal);
1447 
1448  Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal);
1449  Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal);
1450 
1451  dy = rewriter.create<arith::SubIOp>(loc, y, yTrunc);
1452  dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc);
1453  }
1454 
1455  if (op.mode() == "NEAREST_NEIGHBOR") {
1456  Value yPred, xPred;
1457  // Round the index position towards the closest pixel location.
1458  if (floatingPointMode) {
1459  auto halfVal = rewriter.create<arith::ConstantOp>(
1460  loc, rewriter.getF32FloatAttr(0.5f));
1461  yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1462  dy, halfVal);
1463  xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1464  dx, halfVal);
1465  } else {
    // 1 << (shift - 1) is 0.5 in the fixed-point representation.
1466  auto halfVal = rewriter.create<arith::ConstantOp>(
1467  loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
1468  yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1469  dy, halfVal);
1470  xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1471  dx, halfVal);
1472  }
1473 
1474  auto zeroVal = rewriter.create<arith::ConstantOp>(
1475  loc, rewriter.getI32IntegerAttr(0));
1476  auto oneVal = rewriter.create<arith::ConstantOp>(
1477  loc, rewriter.getI32IntegerAttr(1));
1478 
1479  auto yOffset =
1480  rewriter.create<arith::SelectOp>(loc, yPred, oneVal, zeroVal);
1481  auto xOffset =
1482  rewriter.create<arith::SelectOp>(loc, xPred, oneVal, zeroVal);
1483 
1484  iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset);
1485  ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);
1486 
1487  // Clamp the to be within the bounds of the input image.
1488 
1489  iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
1490  arith::CmpIPredicate::slt, rewriter);
1491  ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
1492  arith::CmpIPredicate::slt, rewriter);
1493 
1494  // Read the value from the input array.
1495  iy =
1496  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), iy);
1497  ix =
1498  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), ix);
1499 
1500  Value result = rewriter.create<tensor::ExtractOp>(
1501  loc, input, ValueRange{batch, iy, ix, channel});
1502 
1503  rewriter.create<linalg::YieldOp>(loc, result);
1504 
1505  return success();
1506  }
1507 
1508  if (op.mode() == "BILINEAR") {
    // Gather the 2x2 neighborhood (y0/y1 x x0/x1), each coordinate clamped
    // to the image bounds, then blend with the fractional offsets dx/dy.
1509  Value y0 = iy;
1510  Value x0 = ix;
1511 
1512  auto oneVal = rewriter.create<arith::ConstantOp>(
1513  loc, rewriter.getI32IntegerAttr(1));
1514  Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
1515  Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
1516 
1517  y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
1518  arith::CmpIPredicate::slt, rewriter);
1519  y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
1520  arith::CmpIPredicate::slt, rewriter);
1521 
1522  x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
1523  arith::CmpIPredicate::slt, rewriter);
1524  x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
1525  arith::CmpIPredicate::slt, rewriter);
1526 
1527  y0 =
1528  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
1529  y1 =
1530  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y1);
1531  x0 =
1532  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x0);
1533  x1 =
1534  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x1);
1535 
1536  Value y0x0 = rewriter.create<tensor::ExtractOp>(
1537  loc, input, ValueRange{batch, y0, x0, channel});
1538  Value y0x1 = rewriter.create<tensor::ExtractOp>(
1539  loc, input, ValueRange{batch, y0, x1, channel});
1540  Value y1x0 = rewriter.create<tensor::ExtractOp>(
1541  loc, input, ValueRange{batch, y1, x0, channel});
1542  Value y1x1 = rewriter.create<tensor::ExtractOp>(
1543  loc, input, ValueRange{batch, y1, x1, channel});
1544 
1545  if (floatingPointMode) {
1546  auto oneVal = rewriter.create<arith::ConstantOp>(
1547  loc, rewriter.getF32FloatAttr(1.f));
1548  Value rightPart = dx;
1549  Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
1550 
1551  y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
1552  y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
1553  Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);
1554 
1555  y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
1556  y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
1557  Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
1558 
1559  Value bottomPart = dy;
1560  Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
1561  topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
1562  bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
1563  Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
1564 
1565  rewriter.create<linalg::YieldOp>(loc, result);
1566  return success();
1567  }
    // Integer path: widen the four samples to the result element type and
    // blend in fixed point, where "1.0" is 1 << shift.
1568  y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
1569  y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
1570  y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
1571  y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
1572 
1573  if (resultElementTy.getIntOrFloatBitWidth() > 32) {
1574  dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
1575  dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
1576  }
1577 
1578  auto unitVal = rewriter.create<arith::ConstantOp>(
1579  loc, rewriter.getIntegerAttr(resultElementTy, 1LL << shift));
1580  Value rightPart = dx;
1581  Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx);
1582 
1583  y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
1584  y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
1585  Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
1586 
1587  y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
1588  y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
1589  Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
1590 
1591  Value bottomPart = dy;
1592  Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy);
1593  topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
1594  bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
1595  Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
1596 
1597  rewriter.create<linalg::YieldOp>(loc, result);
1598  return success();
1599  }
1600  return failure();
1601  }
1602 };
1603 
1604 // At the codegen level any identity operations should be removed. Any cases
1605 // where identity is load-bearing (e.g. cross device computation) should be
1606 // handled before lowering to codegen.
1607 template <typename SrcOp>
1608 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1609 public:
 1611 
1612  LogicalResult matchAndRewrite(SrcOp op,
1613  PatternRewriter &rewriter) const final {
    // Replace the op's results directly with its operands.
1614  rewriter.replaceOp(op, op.getOperation()->getOperands());
1615  return success();
1616  }
1617 };
1618 
// Lowers a TOSA reduction op by delegating to reduceMatchAndRewriteHelper
// with the op's axis attribute.
1619 template <typename SrcOp>
1620 class ReduceConverter : public OpRewritePattern<SrcOp> {
1621 public:
 1623 
1624  LogicalResult matchAndRewrite(SrcOp reduceOp,
1625  PatternRewriter &rewriter) const final {
1626  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
1627  }
1628 };
1629 
// Lowers tosa.concat to a zero-filled init tensor into which each operand is
// written with tensor.insert_slice, advancing the offset along the concat
// axis after each insertion.
1630 struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 1632 
 1634  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
1635  ConversionPatternRewriter &rewriter) const override {
1636  auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
1637  auto resultType = op.getType().dyn_cast<RankedTensorType>();
1638 
1639  Location loc = op.getLoc();
1640  int axis = op.axis();
1641  Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
1642  loc, rewriter.getIndexAttr(axis));
1643  int rank = resultType.getRank();
    // Slices are unit-strided and start at offset 0 in every dimension; the
    // offset along `axis` is advanced per operand below.
1644  SmallVector<Value, 3> offsets, sizes, strides;
1645  sizes.reserve(rank);
1646  strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
1647  offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
1648 
    // Sizes start from the first operand; dynamic dims of the first operand
    // also size the init tensor.
1649  SmallVector<Value> dynDims;
1650  for (int i = 0; i < rank; ++i) {
1651  sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
1652  loc, adaptor.getOperands()[0], i));
1653  if (inputType.isDynamicDim(i)) {
1654  dynDims.push_back(
1655  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
1656  }
1657  }
1658 
    // The concat-axis extent of the result is the sum over all operands.
1659  Value resultDimSize = sizes[axis];
1660  for (auto arg : adaptor.getOperands().drop_front()) {
1661  auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
1662  resultDimSize =
1663  rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
1664  }
1665  sizes[axis] = resultDimSize;
1666 
1667  Value init = rewriter.create<linalg::InitTensorOp>(
1668  loc, dynDims, resultType.getShape(), resultType.getElementType());
1669 
1670  Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
1671  loc, rewriter.getZeroAttr(resultType.getElementType()));
1672  Value result =
1673  rewriter
1674  .create<linalg::FillOp>(loc, ValueRange{zeroVal}, ValueRange{init})
1675  .result();
1676 
    // Prefer static attributes over SSA values where the value is a known
    // constant index, so the insert_slice gets foldable static operands.
1677  auto toOpFoldResult = [](Value v) -> OpFoldResult {
1678  auto op = v.getDefiningOp<arith::ConstantIndexOp>();
1679  if (!op)
1680  return v;
1681  return op.getValue();
1682  };
1683  for (auto arg : adaptor.getOperands()) {
1684  sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
1685  result = rewriter.createOrFold<tensor::InsertSliceOp>(
1686  loc, arg, result,
1687  llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
1688  llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
1689  llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
1690  offsets[axis] =
1691  rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
1692  }
1693  rewriter.replaceOp(op, result);
1694  return success();
1695  }
1696 };
1697 
 // Lowers tosa.reverse to a linalg.generic with no inputs whose body computes
 // its own indices and gathers from the original input, flipping the index
 // along the reversed axis to (size - 1 - i).
1698 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1699 public:
1701 
1702  LogicalResult matchAndRewrite(tosa::ReverseOp op,
1703  PatternRewriter &rewriter) const final {
1704  auto loc = op.getLoc();
1705  Value input = op.input();
1706  auto inputTy = input.getType().template cast<ShapedType>();
1707  auto resultTy = op.getType().template cast<ShapedType>();
1708  auto axis = op.axis();
1709 
 // Collect runtime sizes of the dynamic dims so the init tensor can be
 // allocated with the matching shape.
1710  SmallVector<Value> dynDims;
1711  for (int i = 0; i < inputTy.getRank(); i++) {
1712  if (inputTy.isDynamicDim(i)) {
1713  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1714  }
1715  }
1716 
 // Size of the reversed axis, needed to mirror the index in the body.
1717  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1718 
1719  // First fill the output buffer with the init value.
1720  auto initTensor = rewriter
1721  .create<linalg::InitTensorOp>(
1722  loc, ArrayRef<Value>({dynDims}),
1723  inputTy.getShape(), inputTy.getElementType())
1724  .result();
 // Single identity map: the generic only has the output operand.
1725  SmallVector<AffineMap, 2> affineMaps = {
1726  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1727 
1728  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1729  op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
1730  getNParallelLoopsAttrs(resultTy.getRank()),
1731  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1732  llvm::SmallVector<Value> indices;
1733  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1734  auto index =
1735  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1736  if (i == axis) {
 // Mirror the coordinate on the reversed axis: size - 1 - index.
1737  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1738  auto sizeMinusOne =
1739  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1740  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1741  index);
1742  }
1743 
1744  indices.push_back(index);
1745  }
1746 
1747  auto extract = nestedBuilder.create<tensor::ExtractOp>(
1748  nestedLoc, input, indices);
1749  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1750  extract.getResult());
1751  });
1752  return success();
1753  }
1754 };
1755 
1756 // This converter translates a tile operation to a reshape, broadcast, reshape.
1757 // The first reshape minimally expands each tiled dimension to include a
1758 // preceding size-1 dim. This dim is then broadcast to the appropriate
1759 // multiple.
1760 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1762 
 // Lowers tosa.tile by materializing an interleaved (multiple, size) shape
 // of rank 2*N via a broadcasting linalg.generic, then reshaping the result
 // back to the requested output shape with tosa.reshape.
1764  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1765  ConversionPatternRewriter &rewriter) const override {
1766  auto loc = op.getLoc();
1767  auto input = op.input1();
1768  auto inputTy = input.getType().cast<ShapedType>();
1769  auto inputShape = inputTy.getShape();
1770  auto resultTy = op.getType().cast<ShapedType>();
1771  auto elementTy = inputTy.getElementType();
1772  int64_t rank = inputTy.getRank();
1773 
1774  SmallVector<int64_t> multiples;
1775  getValuesFromIntArrayAttribute(op.multiples(), multiples);
1776 
1777  // Broadcast the newly added dimensions to their appropriate multiple.
 // genericShape interleaves [multiple_0, size_0, multiple_1, size_1, ...].
1778  SmallVector<int64_t, 2> genericShape;
1779  for (int i = 0; i < rank; i++) {
1780  genericShape.push_back(multiples[i]);
1781  genericShape.push_back(inputShape[i]);
1782  }
1783 
 // A dim is dynamic in the generic shape if the input dim is dynamic or the
 // multiple itself is unknown (-1).
1784  SmallVector<Value> dynDims;
1785  for (int i = 0; i < inputTy.getRank(); i++) {
1786  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1787  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1788  }
1789  }
1790 
1791  auto initTensor = rewriter.create<linalg::InitTensorOp>(
1792  op.getLoc(), dynDims, genericShape, elementTy);
1793 
1794  // We need to map the input shape to the non-broadcasted dimensions.
 // Input reads use only the odd dims (2*i+1); the even "multiple" dims are
 // broadcast.
1795  SmallVector<AffineExpr, 4> dimExprs;
1796  dimExprs.reserve(rank);
1797  for (unsigned i = 0; i < rank; ++i)
1798  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1799 
1800  auto readAffineMap =
1801  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1802  rewriter.getContext());
1803 
1804  SmallVector<AffineMap, 2> affineMaps = {
1805  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1806 
 // The body just forwards the input element; the affine maps do the
 // replication.
1807  auto genericOp = rewriter.create<linalg::GenericOp>(
1808  loc, RankedTensorType::get(genericShape, elementTy), input,
1809  ValueRange{initTensor}, affineMaps,
1810  getNParallelLoopsAttrs(genericShape.size()),
1811  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1812  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1813  });
1814 
 // Collapse the interleaved 2*N-rank tensor to the final tiled shape.
1815  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1816  op, resultTy, genericOp.getResult(0),
1817  rewriter.getI64ArrayAttr(resultTy.getShape()));
1818  return success();
1819  }
1820 };
1821 
 // Lowers tosa.pad to a tensor.pad (via tensor::createPadScalarOp), reading
 // per-dimension low/high pad amounts out of the padding tensor operand and
 // selecting the pad value from pad_const, the quantization zero point, or a
 // type-appropriate zero.
1822 class PadConverter : public OpRewritePattern<tosa::PadOp> {
1823 public:
1825 
1826  LogicalResult matchAndRewrite(tosa::PadOp padOp,
1827  PatternRewriter &rewriter) const final {
1828  auto loc = padOp.getLoc();
1829  auto input = padOp.input1();
1830  auto padding = padOp.padding();
1831 
1832  ShapedType inputTy = input.getType().cast<ShapedType>();
1833  Type elementTy = inputTy.getElementType();
1834  int64_t rank = inputTy.getRank();
1835 
1836  // Setup the default constantAttr.
1837 
1838  Value padConstant;
1839 
 // Explicit pad_const tensor wins; otherwise fall back to zero or, for
 // quantized integers, to the input zero point.
1840  if (padOp.pad_const()) {
1841  padConstant = rewriter.createOrFold<tensor::ExtractOp>(
1842  loc, padOp.pad_const(), ValueRange({}));
1843  } else {
1844  Attribute constantAttr;
1845  if (elementTy.isa<FloatType>()) {
1846  constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
1847  } else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) {
1848  constantAttr = rewriter.getIntegerAttr(elementTy, 0);
1849  } else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
1850  int64_t value = padOp.quantization_info()->getInputZp();
1851  constantAttr = rewriter.getIntegerAttr(elementTy, value);
1852  }
1853  if (constantAttr)
1854  padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
1855  }
1856 
 // No usable pad value (e.g. unsupported element type): bail out.
1857  if (!padConstant) {
1858  return rewriter.notifyMatchFailure(
1859  padOp, "tosa.pad was unable to determine the pad constant value.");
1860  }
1861 
 // Column indices into the Nx2 padding tensor: 0 = low pad, 1 = high pad.
1862  Value lowIndex =
1863  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
1864  Value highIndex =
1865  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
1866 
1867  SmallVector<OpFoldResult, 3> lowValues;
1868  SmallVector<OpFoldResult, 3> highValues;
1869 
1870  lowValues.reserve(rank);
1871  highValues.reserve(rank);
1872 
 // Extract the low/high pad amount for every dimension and cast them to
 // index type for tensor.pad.
1873  for (int i = 0; i < rank; i++) {
1874  Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
1875  Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
1876  loc, padding, ValueRange({inputIndex, lowIndex}));
1877  Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
1878  loc, padding, ValueRange({inputIndex, highIndex}));
1879 
1880  lowVal = rewriter.createOrFold<arith::IndexCastOp>(
1881  loc, rewriter.getIndexType(), lowVal);
1882  highVal = rewriter.createOrFold<arith::IndexCastOp>(
1883  loc, rewriter.getIndexType(), highVal);
1884 
1885  lowValues.push_back(lowVal);
1886  highValues.push_back(highVal);
1887  }
1888 
1889  auto newPadOp = tensor::createPadScalarOp(
1890  padOp.getType(), input, padConstant, lowValues, highValues,
1891  /*nofold=*/false, loc, rewriter);
1892 
1893  rewriter.replaceOp(padOp, newPadOp.getResult());
1894  return success();
1895  }
1896 };
1897 
1898 // Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
1899 // op, producing two output buffers.
1900 //
1901 // The first output buffer contains the index of the found maximum value. It is
1902 // initialized to 0 and is of the resulting integer type.
1903 //
1904 // The second output buffer contains the maximum value found. It is initialized
1905 // to the minimum representable value of the input element type. After being
1906 // populated by indexed_generic, this buffer is discarded as only the index is
1907 // requested.
1908 //
1909 // The indexed_generic op updates both the maximum value and index if the
1910 // current value exceeds the running max.
1911 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1912 public:
1914 
 // Lowers tosa.argmax as a two-result reduction: one accumulator for the
 // winning index (the returned result) and one for the running maximum
 // value (used only inside the reduction and discarded afterwards).
1915  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1916  PatternRewriter &rewriter) const final {
1917  auto loc = argmaxOp.getLoc();
1918  Value input = argmaxOp.input();
1919  auto inputTy = input.getType().cast<ShapedType>();
1920  auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
1921  auto inElementTy = inputTy.getElementType();
1922  auto outElementTy = resultTy.getElementType();
1923  int axis = argmaxOp.axis();
 // Same shape as the index result, but holding input-typed max values.
1924  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1925 
1926  if (!outElementTy.isa<IntegerType>())
1927  return rewriter.notifyMatchFailure(
1928  argmaxOp,
1929  "tosa.arg_max to linalg.* requires integer-like result type");
1930 
 // Dynamic dims of the (reduced) outputs: every dynamic input dim except
 // the reduction axis, which disappears from the result shape.
1931  SmallVector<Value> dynDims;
1932  for (int i = 0; i < inputTy.getRank(); i++) {
1933  if (inputTy.isDynamicDim(i) && i != axis) {
1934  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1935  }
1936  }
1937 
1938  // First fill the output buffer for the index.
1939  auto initTensorIdx =
1940  rewriter
1941  .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
1942  outElementTy)
1943  .result();
1944  auto fillValueIdx = rewriter.create<arith::ConstantOp>(
1945  loc, rewriter.getIntegerAttr(outElementTy, 0));
1946  auto filledTensorIdx =
1947  rewriter
1948  .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
1949  ValueRange{initTensorIdx})
1950  .result();
1951 
1952  // Second fill the output buffer for the running max.
1953  auto initTensorMax = rewriter
1954  .create<linalg::InitTensorOp>(
1955  loc, dynDims, resultTy.getShape(), inElementTy)
1956  .result();
 // Reuses the reduce-op identity value (minimum representable value for a
 // max reduction) as the initial running max.
1957  auto fillValueMaxAttr =
1958  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
1959 
1960  if (!fillValueMaxAttr)
1961  return rewriter.notifyMatchFailure(
1962  argmaxOp, "unsupported tosa.argmax element type");
1963 
1964  auto fillValueMax =
1965  rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
1966  auto filledTensorMax =
1967  rewriter
1968  .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
1969  ValueRange{initTensorMax})
1970  .result();
1971 
1972  // We need to reduce along the arg-max axis, with parallel operations along
1973  // the rest.
1974  SmallVector<StringRef, 4> iteratorTypes;
1975  iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
1976  iteratorTypes[axis] = getReductionIteratorTypeName();
1977 
 // src map is the identity; both outputs drop the reduction axis.
1978  SmallVector<AffineExpr, 2> srcExprs;
1979  SmallVector<AffineExpr, 2> dstExprs;
1980  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
1981  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
1982  if (axis != i)
1983  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
1984  }
1985 
 // The body cannot signal failure directly, so an unsupported element type
 // is recorded here and reported after the op is built.
1986  bool didEncounterError = false;
1987  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
1988  auto linalgOp = rewriter.create<linalg::GenericOp>(
1989  loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
1990  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
1991  [&](OpBuilder &nestedBuilder, Location nestedLoc,
1992  ValueRange blockArgs) {
1993  auto newValue = blockArgs[0];
1994  auto oldIndex = blockArgs[1];
1995  auto oldValue = blockArgs[2];
1996 
 // Candidate index is the current position along the reduction axis.
1997  Value newIndex = rewriter.create<arith::IndexCastOp>(
1998  nestedLoc, oldIndex.getType(),
1999  rewriter.create<linalg::IndexOp>(loc, axis));
2000 
2001  Value predicate;
2002  if (inElementTy.isa<FloatType>()) {
2003  predicate = rewriter.create<arith::CmpFOp>(
2004  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2005  } else if (inElementTy.isa<IntegerType>()) {
2006  predicate = rewriter.create<arith::CmpIOp>(
2007  nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2008  } else {
2009  didEncounterError = true;
2010  return;
2011  }
2012 
 // Update both accumulators when the new value strictly exceeds the
 // running max (strict '>' keeps the first occurrence on ties).
2013  auto resultMax = rewriter.create<arith::SelectOp>(
2014  nestedLoc, predicate, newValue, oldValue);
2015  auto resultIndex = rewriter.create<arith::SelectOp>(
2016  nestedLoc, predicate, newIndex, oldIndex);
2017  nestedBuilder.create<linalg::YieldOp>(
2018  nestedLoc, ValueRange({resultIndex, resultMax}));
2019  });
2020 
2021  if (didEncounterError)
2022  return rewriter.notifyMatchFailure(
2023  argmaxOp, "unsupported tosa.argmax element type");
2024 
 // Only the index result is returned; the max-value result is dropped.
2025  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2026  return success();
2027  }
2028 };
2029 
 // Lowers tosa.gather to a linalg.generic over the indices tensor whose body
 // does a tensor.extract from the values tensor at (batch, index, channel).
2030 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2031 public:
2034  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2035  ConversionPatternRewriter &rewriter) const final {
2036  auto input = adaptor.getOperands()[0];
2037  auto indices = adaptor.getOperands()[1];
2038 
2039  auto resultTy = op.getType().cast<ShapedType>();
2040 
 // Only a dynamic batch dim is supported; the helper rejects other dynamic
 // dims by returning None.
2041  auto dynamicDimsOr =
2042  checkHasDynamicBatchDims(rewriter, op, {input, indices, op.output()});
2043  if (!dynamicDimsOr.hasValue())
2044  return failure();
2045  SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
2046 
2047  auto resultElementTy = resultTy.getElementType();
2048 
2049  auto loc = op.getLoc();
2050 
2051  auto initTensor =
2052  rewriter
2053  .create<linalg::InitTensorOp>(loc, dynamicDims, resultTy.getShape(),
2054  resultElementTy)
2055  .result();
2056 
 // Indices map on (batch, index) dims; output uses the identity map.
2057  SmallVector<AffineMap, 2> affineMaps = {
2059  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2060  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2061  rewriter.getContext()),
2062  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2063 
2064  auto genericOp = rewriter.create<linalg::GenericOp>(
2065  loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2066  ValueRange{initTensor}, affineMaps,
2067  getNParallelLoopsAttrs(resultTy.getRank()),
2068  [&](OpBuilder &b, Location loc, ValueRange args) {
 // Gather coordinate: dim 0 = batch, dim 1 = the dynamic index read
 // from the indices tensor, dim 2 = channel.
2069  auto indexValue = args[0];
2070  auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2071  Value index1 = rewriter.create<arith::IndexCastOp>(
2072  loc, rewriter.getIndexType(), indexValue);
2073  auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2074  Value extract = rewriter.create<tensor::ExtractOp>(
2075  loc, input, ValueRange{index0, index1, index2});
2076  rewriter.create<linalg::YieldOp>(loc, extract);
2077  });
2078  rewriter.replaceOp(op, genericOp.getResult(0));
2079  return success();
2080  }
2081 };
2082 
2083 // Lowers the TableOp to a series of gathers and numeric operations. This
2084 // includes interpolation between the high/low values. For the I8 variant, this
2085 // simplifies to a single gather operation.
2086 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2087 public:
2089 
 // Lowers tosa.table to a linalg.generic whose body does the table lookup:
 // i8 tables are a single offset gather; i16->i32 tables interpolate between
 // adjacent table entries using the low 7 bits as the fraction.
2090  LogicalResult matchAndRewrite(tosa::TableOp op,
2091  PatternRewriter &rewriter) const final {
2092  auto loc = op.getLoc();
2093  Value input = op.input();
2094  Value table = op.table();
2095  auto inputTy = input.getType().cast<ShapedType>();
2096  auto tableTy = table.getType().cast<ShapedType>();
2097  auto resultTy = op.getType().cast<ShapedType>();
2098 
2099  auto inputElementTy = inputTy.getElementType();
2100  auto tableElementTy = tableTy.getElementType();
2101  auto resultElementTy = resultTy.getElementType();
2102 
2103  SmallVector<Value> dynDims;
2104  for (int i = 0; i < resultTy.getRank(); ++i) {
2105  if (inputTy.isDynamicDim(i)) {
2106  dynDims.push_back(
2107  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2108  }
2109  }
2110 
2111  auto initTensor =
2112  rewriter
2113  .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
2114  resultElementTy)
2115  .result();
2116 
2117  SmallVector<AffineMap, 2> affineMaps = {
2118  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2119  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2120 
 // The generic is created without a region; its body block is built below
 // under an insertion guard so the early `return success()` paths can exit
 // from inside the block scope.
2121  auto genericOp = rewriter.create<linalg::GenericOp>(
2122  loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
2123  getNParallelLoopsAttrs(resultTy.getRank()));
2124  rewriter.replaceOp(op, genericOp.getResult(0));
2125 
2126  {
2127  OpBuilder::InsertionGuard regionGuard(rewriter);
2128  Block *block = rewriter.createBlock(
2129  &genericOp.region(), genericOp.region().end(),
2130  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2131 
2132  auto inputValue = block->getArgument(0);
2133  rewriter.setInsertionPointToStart(block);
 // i8 path: bias the signed input by 128 to index the 256-entry table.
2134  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2135  resultElementTy.isInteger(8)) {
2136  Value index = rewriter.create<arith::IndexCastOp>(
2137  loc, rewriter.getIndexType(), inputValue);
2138  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2139  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2140  index, offset);
2141  Value extract =
2142  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2143  rewriter.create<linalg::YieldOp>(loc, extract);
2144  return success();
2145  }
2146 
 // i16 -> i32 path: split the biased input into a table index (high bits)
 // and a 7-bit fraction, then linearly interpolate adjacent entries.
2147  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2148  resultElementTy.isInteger(32)) {
2149  Value extend = rewriter.create<arith::ExtSIOp>(
2150  loc, rewriter.getI32Type(), inputValue);
2151 
2152  auto offset = rewriter.create<arith::ConstantOp>(
2153  loc, rewriter.getI32IntegerAttr(32768));
2154  auto seven = rewriter.create<arith::ConstantOp>(
2155  loc, rewriter.getI32IntegerAttr(7));
2156  auto one = rewriter.create<arith::ConstantOp>(
2157  loc, rewriter.getI32IntegerAttr(1));
2158  auto b1111111 = rewriter.create<arith::ConstantOp>(
2159  loc, rewriter.getI32IntegerAttr(127));
2160 
2161  // Compute the index and fractional part from the input value:
2162  // value = value + 32768
2163  // index = value >> 7;
2164  // fraction = 0x01111111 & value
2165  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2166  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2167  Value fraction =
2168  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2169 
2170  // Extract the base and next values from the table.
2171  // base = (int32_t) table[index];
2172  // next = (int32_t) table[index + 1];
2173  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2174 
2175  index = rewriter.create<arith::IndexCastOp>(
2176  loc, rewriter.getIndexType(), index);
2177  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2178  loc, rewriter.getIndexType(), indexPlusOne);
2179 
2180  Value base =
2181  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2182  Value next = rewriter.create<tensor::ExtractOp>(
2183  loc, table, ValueRange{indexPlusOne});
2184 
2185  base =
2186  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2187  next =
2188  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2189 
2190  // Use the fractional part to interpolate between the input values:
2191  // result = (base << 7) + (next - base) * fraction
2192  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2193  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2194  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2195  Value result =
2196  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2197 
2198  rewriter.create<linalg::YieldOp>(loc, result);
2199 
2200  return success();
2201  }
2202  }
2203 
 // Unsupported element-type combination.
2204  return rewriter.notifyMatchFailure(
2205  op, "unable to create body for tosa.table op");
2206  }
2207 };
2208 
2209 } // namespace
2210 
 // Registers every TOSA-to-Linalg conversion pattern defined in this file
 // (elementwise, reduction, data-movement, and lookup ops) into the set.
2212  RewritePatternSet *patterns) {
2213  patterns->add<
2214  // clang-format off
2215  PointwiseConverter<tosa::AddOp>,
2216  PointwiseConverter<tosa::SubOp>,
2217  PointwiseConverter<tosa::MulOp>,
2218  PointwiseConverter<tosa::DivOp>,
2219  PointwiseConverter<tosa::NegateOp>,
2220  PointwiseConverter<tosa::PowOp>,
2221  PointwiseConverter<tosa::ReciprocalOp>,
2222  PointwiseConverter<tosa::RsqrtOp>,
2223  PointwiseConverter<tosa::LogOp>,
2224  PointwiseConverter<tosa::ExpOp>,
2225  PointwiseConverter<tosa::AbsOp>,
2226  PointwiseConverter<tosa::TanhOp>,
2227  PointwiseConverter<tosa::BitwiseAndOp>,
2228  PointwiseConverter<tosa::BitwiseOrOp>,
2229  PointwiseConverter<tosa::BitwiseNotOp>,
2230  PointwiseConverter<tosa::BitwiseXorOp>,
2231  PointwiseConverter<tosa::LogicalAndOp>,
2232  PointwiseConverter<tosa::LogicalNotOp>,
2233  PointwiseConverter<tosa::LogicalOrOp>,
2234  PointwiseConverter<tosa::LogicalXorOp>,
2235  PointwiseConverter<tosa::CastOp>,
2236  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2237  PointwiseConverter<tosa::LogicalRightShiftOp>,
2238  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2239  PointwiseConverter<tosa::ClzOp>,
2240  PointwiseConverter<tosa::SelectOp>,
2241  PointwiseConverter<tosa::GreaterOp>,
2242  PointwiseConverter<tosa::GreaterEqualOp>,
2243  PointwiseConverter<tosa::EqualOp>,
2244  PointwiseConverter<tosa::MaximumOp>,
2245  PointwiseConverter<tosa::MinimumOp>,
2246  PointwiseConverter<tosa::CeilOp>,
2247  PointwiseConverter<tosa::FloorOp>,
2248  PointwiseConverter<tosa::ClampOp>,
2249  PointwiseConverter<tosa::ReluNOp>,
2250  PointwiseConverter<tosa::SigmoidOp>,
2251  IdentityNConverter<tosa::IdentityOp>,
2252  ReduceConverter<tosa::ReduceAllOp>,
2253  ReduceConverter<tosa::ReduceAnyOp>,
2254  ReduceConverter<tosa::ReduceMinOp>,
2255  ReduceConverter<tosa::ReduceMaxOp>,
2256  ReduceConverter<tosa::ReduceSumOp>,
2257  ReduceConverter<tosa::ReduceProdOp>,
2258  ArgMaxConverter,
2259  ConcatConverter,
2260  GatherConverter,
2261  PadConverter,
2262  ReshapeConverterCollapse,
2263  ReshapeConverterExpand,
2264  ReshapeConverterCollapseExpand,
2265  RescaleConverter,
2266  ResizeConverter,
2267  ReverseConverter,
2268  TableConverter,
2269  TileConverter,
2270  TransposeConverter>(patterns->getContext());
2271  // clang-format on
2272 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
U cast() const
Definition: Location.h:67
MLIRContext * getContext() const
Definition: Builders.h:54
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
U cast() const
Definition: Attributes.h:130
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:458
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:308
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:264
operand_range getOperands()
Returns an iterator on the underlying Value&#39;s.
Definition: Operation.h:302
Specialization of arith.constant op that returns an integer value.
Definition: Arithmetic.h:43
Block represents an ordered list of Operations.
Definition: Block.h:29
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
This class represents a single result from folding an operation.
Definition: OpDefinition.h:229
Value getOperand(unsigned idx)
Definition: Operation.h:274
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
unsigned getNumOperands()
Definition: Operation.h:270
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:688
void getValuesFromIntArrayAttribute(ArrayAttr attr, SmallVector< T > &arrayValues)
FloatType getF32Type()
Definition: Builders.cpp:40
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
SmallVector< StringRef > getNParallelLoopsAttrs(unsigned nParallelLoops)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
Optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:61
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:256
IntegerAttr getI8IntegerAttr(int8_t value)
Definition: Builders.cpp:166
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
PadOp createPadScalarOp(Type type, Value source, Value pad, ArrayRef< OpFoldResult > low, ArrayRef< OpFoldResult > high, bool nofold, Location loc, OpBuilder &builder)
Definition: Utils.cpp:21
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
OpResult getResult(unsigned idx)
Get the &#39;idx&#39;th result of this operation.
Definition: Operation.h:331
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:161
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:259
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:489
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
static bool findIntermediateShape(ArrayRef< int64_t > lhsShape, ArrayRef< int64_t > rhsShape, SmallVector< int64_t > &intermediateShape, bool isDynamic)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, PatternRewriter &rewriter)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static SmallVector< int64_t, 8 > subtract(ArrayRef< int64_t > vecA, ArrayRef< int64_t > vecB)
static int resultIndex(int i)
Definition: Operator.cpp:308
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:369
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:286
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter)
Type getType() const
Return the type of this value.
Definition: Value.h:118
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types &#39;Ts&#39; to the pattern list with the given arguments...
IndexType getIndexType()
Definition: Builders.cpp:48
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:333
static bool createReassociationMapsForCollapse(PatternRewriter &rewriter, ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape, SmallVector< ReassociationExprs, 4 > &reassociationMap, bool isDynamic)
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:285
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:512
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:328
SlowMPInt abs(const SlowMPInt &x)
Redeclarations of friend declarations above to make it discoverable by lookups.
Definition: SlowMPInt.cpp:200
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with &#39;argTypes&#39; arguments and set the insertion point to the end of it...
Definition: Builders.cpp:353
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, PatternRewriter &rewriter)
bool isa() const
Definition: Types.h:246
constexpr StringRef getReductionIteratorTypeName()
Use to encode that a particular iterator type has reduction semantics.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:235
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:185
result_range getResults()
Definition: Operation.h:339
This class helps build Operations.
Definition: Builders.h:184
This class provides an abstraction over the different types of ranges over Values.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:378
static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter)
MLIRContext * getContext() const
IntegerType getI32Type()
Definition: Builders.cpp:54
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
An attribute that represents a reference to a dense integer vector or tensor object.
U cast() const
Definition: Types.h:262