//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

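// Materializes an integer attribute from the op as an arith.constant of the
// requested integer type, passing the stored value through the integer type T.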
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

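// Builds the scalar computation for one element of a TOSA elementwise op, to
// be inlined into the body of a linalg.generic. Returns a null Value (after
// signaling a match failure) for unsupported op/element-type combinations.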
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::AbsOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                              args[0], zero);
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
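    // A non-zero shift requests TOSA's fixed-point multiply: widen both
    // operands to i32 and let tosa.apply_scale compute the rounded
    // (a * b) >> shift, then truncate back to the narrower result type.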
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //   outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp = clampHelper<arith::CmpIOp>(
        loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

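    // Rounding adds 1 to the shifted result exactly when the shift amount is
    // positive and the most significant bit shifted out was set:
    //   result += (args[1] > 0) && (((args[0] >> (args[1] - 1)) & 1) == 1)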
    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Check that the shift amount (input2) is greater than zero.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Check whether the bit shifted out last from input1 was a 1.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
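  // Count leading zeros with an scf.while loop: starting from the full
  // bitwidth, decrement once per logical right shift until the input reaches
  // zero. An input of 0 therefore yields the bitwidth, matching CLZ semantics.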
  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
    int bitWidth = elementTy.getIntOrFloatBitWidth();
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto leadingZeros = rewriter.create<arith::ConstantOp>(
        loc, IntegerAttr::get(elementTy, bitWidth));

    SmallVector<Value> operands = {args[0], leadingZeros, zero};
    SmallVector<Type> types = {elementTy, elementTy, elementTy};
    SmallVector<Location> locations = {loc, loc, loc};

    auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
    Block *before =
        rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
    Block *after =
        rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);

    // The conditional block of the while loop.
    {
      rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
      Value input = before->getArgument(0);
      Value zero = before->getArgument(2);

      Value inputLargerThanZero = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ne, input, zero);
      rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
                                        before->getArguments());
    }

    // The body of the while loop: shift right until reaching a value of 0.
    {
      rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
      Value input = after->getArgument(0);
      Value leadingZeros = after->getArgument(1);

      auto one = rewriter.create<arith::ConstantOp>(
          loc, IntegerAttr::get(elementTy, 1));
      auto shifted =
          rewriter.create<arith::ShRUIOp>(loc, resultTypes, input, one);
      auto leadingZerosMinusOne =
          rewriter.create<arith::SubIOp>(loc, resultTypes, leadingZeros, one);

      rewriter.create<scf::YieldOp>(
          loc,
          ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
    }

    rewriter.setInsertionPointAfter(whileOp);
    return whileOp->getResult(1);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OGT, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                  op->getAttr("min_fp"));
    auto max = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                  op->getAttr("max_fp"));
    return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
                                      arith::CmpFPredicate::OLT, rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto intTy = elementTy.cast<IntegerType>();
    int32_t min = static_cast<int32_t>(
        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
    int32_t max = static_cast<int32_t>(
        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());

    if (intTy.isUnsignedInteger()) {
      min = std::max<int32_t>(min, 0);
      max = std::min<int32_t>(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min = std::max<int32_t>(
          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
      max = std::min<int32_t>(
          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
    }

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
                                      arith::CmpIPredicate::slt, rewriter);
  }

  // tosa::ReluNOp
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                op->getAttr("max_fp"));
    return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
                                      arith::CmpFPredicate::OLT, rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
                                      arith::CmpIPredicate::slt, rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             mlir::None);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              mlir::None);

    // When casting to boolean, floats need only be checked as not-equal to
    // zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(0.0f));
      auto half = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(0.5f));

      auto intMin = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Round away from zero, then clamp to the destination range before
      // converting.
      auto added = rewriter.create<arith::AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<arith::SubFOp>(loc, args[0], half);
      auto negative = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampHelper<arith::CmpFOp>(
          loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);

      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
    }

    // When casting to boolean, integers need only be checked as not-equal to
    // zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             mlir::None);

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      // Narrowing casts saturate: clamp to the destination's signed range in
      // the source width, then truncate.
      auto intMin = rewriter.create<arith::ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<arith::ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampHelper<arith::CmpIOp>(
          loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
      return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

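// Lowers a TOSA elementwise op to a linalg.generic whose body is produced by
// createLinalgBodyCalculationForElementwiseOp, broadcasting operands as
// needed.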
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Collect the element types of the operands for the linalg.generic body.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;

  SmallVector<Value> dynDims;
  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());

  for (auto arg : operation->getOperands()) {
    auto operandTy = arg.getType().cast<ShapedType>();
    for (int i = 0; i < operandTy.getRank(); i++) {
      if (operandTy.isDynamicDim(i) && !dynDims[i])
        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
    }
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);

  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
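  // Dimensions that already match the result keep an identity expression;
  // broadcast (size-1) dimensions are dropped from both the operand (via a
  // reshape) and its indexing map. For example, broadcasting tensor<1x5xf32>
  // to tensor<3x5xf32> reshapes the operand to tensor<5xf32> with the map
  // (d0, d1) -> (d1).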
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (const auto &it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
          rewriter.getI64ArrayAttr(newShape));
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OGT, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  llvm::SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto initTensor = rewriter
                        .create<linalg::InitTensorOp>(loc, dynDims, reduceShape,
                                                      resultTy.getElementType())
                        .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        // A non-null result means a valid body was built; record that so the
        // pattern can succeed.
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return failure();

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}

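// Computes a shape that both lhsShape and rhsShape can be reshaped to, by
// grouping contiguous runs of dimensions whose products agree. This lets an
// arbitrary tosa.reshape be split into a collapse followed by an expand.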
static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
                                  ArrayRef<int64_t> rhsShape,
                                  SmallVector<int64_t> &intermediateShape,
                                  bool isDynamic) {
  if (isDynamic) {
    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
    intermediateShape = {-1};
    return true;
  }

  if (lhsShape.empty() || rhsShape.empty()) {
    intermediateShape = {};
    return true;
  }

  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        lhsSize *= lhsShape[currLhsDim];
      } else {
        currRhsDim++;
        rhsSize *= rhsShape[currRhsDim];
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // If an iterator didn't reach the end, an intermediate shape exists only if
  // all of its leftover dimensions are equal to 1.
  while (currLhsDim < lhsShape.size()) {
    if (lhsShape[currLhsDim++] != 1) {
      return false;
    }
  }

  while (currRhsDim < rhsShape.size()) {
    if (rhsShape[currRhsDim++] != 1) {
      return false;
    }
  }

  return true;
}

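// Builds the reassociation map that collapses srcShape into dstShape: each
// destination dimension is associated with the contiguous source dimensions
// whose product equals its size. Returns false if the shapes are incompatible.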
static bool createReassociationMapsForCollapse(
    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
    ArrayRef<int64_t> dstShape,
    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {

  // If the shape is dynamic, create a map for collapsing into one dimension.
  if (isDynamic) {
    SmallVector<AffineExpr, 2> exprs;
    for (int i = 0, s = srcShape.size(); i < s; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    reassociationMap = {exprs};
    return true;
  }

  if (dstShape.empty()) {
    reassociationMap = {};
    return true;
  }

  reassociationMap.resize(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              rewriter.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If either iterator didn't reach the end, there are leftover dimensions,
  // which implies a mismatch in shape.
  return !(currSrcDim != srcShape.size() || currDstDim != dstShape.size());
}

namespace {

template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};

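// Lowers tosa.reshape ops that only collapse dimensions into
// tensor.collapse_shape.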
class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
    bool isDynamic = !operandTy.hasStaticShape();

    if (isDynamic && resultTy.getRank() != 1) {
      return rewriter.notifyMatchFailure(
          reshape, "Cannot collapse dynamic dims to more than one dimension");
    }

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    SmallVector<ReassociationExprs, 4> reassociationMap;
    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
                                            resultTy.getShape(),
                                            reassociationMap, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape,
          "tosa.reshape Attempting to collapse into an incompatible shape");
    }

    SmallVector<int64_t> intermediateShape;
    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
                               intermediateShape, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot collapse into given shape");
    }

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
    return success();
  }
};

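// Lowers tosa.reshape ops that only expand dimensions into
// tensor.expand_shape.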
class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
    bool isDynamic = !operandTy.hasStaticShape();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    if (isDynamic && operandTy.getRank() != 1) {
      return rewriter.notifyMatchFailure(
          reshape, "Cannot expand dynamic dims from more than one dimension");
    }

    SmallVector<ReassociationExprs, 4> reassociationMap;
    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
                                            operandTy.getShape(),
                                            reassociationMap, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape,
          "tosa.reshape Attempting to expand into an incompatible shape");
    }

    SmallVector<int64_t> intermediateShape;
    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
                               intermediateShape, isDynamic) ||
        intermediateShape != operandTy.getShape()) {
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot expand into given shape");
    }
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
    return success();
  }
};

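// Lowers the remaining tosa.reshape ops by splitting them into a collapsing
// reshape followed by an expanding reshape through the intermediate shape,
// which the two patterns above then convert.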
class ReshapeConverterCollapseExpand
    : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
    bool isDynamic = !operandTy.hasStaticShape();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    SmallVector<int64_t> intermediateShape;
    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
                               intermediateShape, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot identify an intermediate shape between "
                   "the given two shapes");
    }

    Value collapse = rewriter.create<tosa::ReshapeOp>(
        reshape.getLoc(),
        RankedTensorType::get(intermediateShape,
                              reshape.getType().getElementType()),
        adaptor.input1());
    Value expand =
        rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
    rewriter.replaceOp(reshape, expand);

    return success();
  }
};

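// Lowers tosa.transpose to a linalg.generic whose input indexing map applies
// the permutation taken from the constant perms operand.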
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms))) {
      return failure();
    }

    auto loc = op.getLoc();
    auto input = op->getOperand(0);
    auto resultTy = op.getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    auto operandTy = input.getType().cast<ShapedType>();
    for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
      auto index = permutation.index();
      auto value = permutation.value().getZExtValue();
      if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
      }
      inputExprs[value] = rewriter.getAffineDimExpr(index);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType());

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });
    return success();
  }
};

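// Lowers tosa.rescale to a linalg.generic that, per element, subtracts the
// input zero-point, applies the (possibly per-channel) multiplier and shift
// via tosa.apply_scale, adds the output zero-point, and saturates to the
// output type's range.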
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.input();
    auto inputTy = op.input().getType().cast<ShapedType>();
    auto outputTy = op.output().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.double_round() && !op.scale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
    if (!dynamicDimsOr.hasValue())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.shift(), shiftValues);

    // Shifting by more than the bitwidth always produces 0, so encode that
    // directly by zeroing out the multiplier and shift.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only has an effect when the shift is greater than 31,
    // so check whether that is ever true.
    bool doubleRound =
        op.double_round() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Create the destination tensor for the linalg.generic op.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // Widen to an intermediate bitwidth: 48-bit for inputs wider than
          // 32 bits, otherwise 32-bit. This is not optimal, but should be
          // correct for now; consider computing the exact required bit depth
          // later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampHelper<arith::CmpIOp>(
              nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
              nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};

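// Lowers tosa.resize (NEAREST_NEIGHBOR and BILINEAR modes) to a linalg.generic
// that computes each output pixel from the input image, using floating-point
// coordinate arithmetic when shift == 0 and fixed-point arithmetic otherwise.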
1379 class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1380 public:
1382 
1383  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1384  PatternRewriter &rewriter) const final {
1385  Location loc = op.getLoc();
1386  auto input = op.input();
1387  auto inputTy = input.getType().cast<ShapedType>();
1388  auto resultTy = op.getType().cast<ShapedType>();
1389  auto resultElementTy = resultTy.getElementType();
1390 
1391  auto imageH = inputTy.getShape()[1];
1392  auto imageW = inputTy.getShape()[2];
1393 
1394  auto dynamicDimsOr =
1395  checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
1396  if (!dynamicDimsOr.hasValue())
1397  return failure();
1398  SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
1399 
1400  if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
1401  return failure();
1402 
1403  auto initTensor = rewriter.create<linalg::InitTensorOp>(
1404  loc, dynamicDims, resultTy.getShape(), resultElementTy);
1405 
1406  SmallVector<AffineMap, 2> affineMaps = {
1407  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1408 
1409  auto genericOp = rewriter.create<linalg::GenericOp>(
1410  loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
1411  getNParallelLoopsAttrs(resultTy.getRank()));
1412  rewriter.replaceOp(op, genericOp.getResult(0));
1413 
1414  OpBuilder::InsertionGuard regionGuard(rewriter);
1415  rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
1416  TypeRange({resultElementTy}), loc);
1417  Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
1418  Value y = rewriter.create<linalg::IndexOp>(loc, 1);
1419  Value x = rewriter.create<linalg::IndexOp>(loc, 2);
1420  Value channel = rewriter.create<linalg::IndexOp>(loc, 3);
1421 
1422  auto hwMin =
1423  rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1424  auto hMax = rewriter.create<arith::ConstantOp>(
1425  loc, rewriter.getI32IntegerAttr(imageH - 1));
1426  auto wMax = rewriter.create<arith::ConstantOp>(
1427  loc, rewriter.getI32IntegerAttr(imageW - 1));
1428 
1429  Value inY =
1430  rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
1431  Value inX =
1432  rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);
1433 
1434  int32_t shift = op.shift();
1435  bool floatingPointMode = shift == 0;
1436 
1437  Value yStride, xStride, yOffset, xOffset;
1438  if (floatingPointMode) {
1439  yStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[0]);
1440  xStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[1]);
1441  yOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[0]);
1442  xOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[1]);
1443  } else {
1444  SmallVector<int32_t> stride, offset;
1445  getValuesFromIntArrayAttribute(op.stride(), stride);
1446  getValuesFromIntArrayAttribute(op.offset(), offset);
1447 
1448  yStride = rewriter.create<arith::ConstantOp>(
1449  loc, rewriter.getI32IntegerAttr(stride[0]));
1450  xStride = rewriter.create<arith::ConstantOp>(
1451  loc, rewriter.getI32IntegerAttr(stride[1]));
1452  yOffset = rewriter.create<arith::ConstantOp>(
1453  loc, rewriter.getI32IntegerAttr(offset[0]));
1454  xOffset = rewriter.create<arith::ConstantOp>(
1455  loc, rewriter.getI32IntegerAttr(offset[1]));
1456  }
1457 
1458  // Compute the the integer index and partial offset.
1459  // x = x * stride + offset;
1460  // ix = floor(x)
1461  // dx = x - ix
1462  Value ix, iy, dx, dy;
1463  if (floatingPointMode) {
1464  Value y =
1465  rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
1466  Value x =
1467  rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);
1468 
1469  y = rewriter.create<arith::MulFOp>(loc, y, yStride);
1470  x = rewriter.create<arith::MulFOp>(loc, x, xStride);
1471 
1472  y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
1473  x = rewriter.create<arith::AddFOp>(loc, x, xOffset);
1474 
1475  iy = rewriter.create<math::FloorOp>(loc, y);
1476  ix = rewriter.create<math::FloorOp>(loc, x);
1477 
1478  dy = rewriter.create<arith::SubFOp>(loc, y, iy);
1479  dx = rewriter.create<arith::SubFOp>(loc, x, ix);
1480 
1481  iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
1482  ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
1483  } else {
1484  Value shiftVal = rewriter.create<arith::ConstantOp>(
1485  loc, rewriter.getI32IntegerAttr(shift));
1486 
1487  Value y = rewriter.create<arith::MulIOp>(loc, inY, yStride);
1488  Value x = rewriter.create<arith::MulIOp>(loc, inX, xStride);
1489 
1490  y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
1491  x = rewriter.create<arith::AddIOp>(loc, x, xOffset);
1492 
1493  iy = rewriter.create<arith::ShRSIOp>(loc, y, shiftVal);
1494  ix = rewriter.create<arith::ShRSIOp>(loc, x, shiftVal);
1495 
1496  Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal);
1497  Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal);
1498 
1499  dy = rewriter.create<arith::SubIOp>(loc, y, yTrunc);
1500  dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc);
1501  }
1502 
1503  if (op.mode() == "NEAREST_NEIGHBOR") {
1504  Value yPred, xPred;
1505  // Round the index position towards the closest pixel location.
1506  if (floatingPointMode) {
1507  auto halfVal = rewriter.create<arith::ConstantOp>(
1508  loc, rewriter.getF32FloatAttr(0.5f));
1509  yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1510  dy, halfVal);
1511  xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1512  dx, halfVal);
1513  } else {
1514  auto halfVal = rewriter.create<arith::ConstantOp>(
1515  loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
1516  yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1517  dy, halfVal);
1518  xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1519  dx, halfVal);
1520  }
1521 
1522  auto zeroVal = rewriter.create<arith::ConstantOp>(
1523  loc, rewriter.getI32IntegerAttr(0));
1524  auto oneVal = rewriter.create<arith::ConstantOp>(
1525  loc, rewriter.getI32IntegerAttr(1));
1526 
1527  auto yOffset =
1528  rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
1529  auto xOffset =
1530  rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);
1531 
1532  iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset);
1533  ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);
1534 
1535  // Clamp the values to be within the bounds of the input image.
1536 
1537  iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
1538  arith::CmpIPredicate::slt, rewriter);
1539  ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
1540  arith::CmpIPredicate::slt, rewriter);
1541 
1542  // Read the value from the input array.
1543  iy =
1544  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), iy);
1545  ix =
1546  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), ix);
1547 
1548  Value result = rewriter.create<tensor::ExtractOp>(
1549  loc, input, ValueRange{batch, iy, ix, channel});
1550 
1551  rewriter.create<linalg::YieldOp>(loc, result);
1552 
1553  return success();
1554  }
1555 
1556  if (op.mode() == "BILINEAR") {
1557  Value y0 = iy;
1558  Value x0 = ix;
1559 
1560  auto oneVal = rewriter.create<arith::ConstantOp>(
1561  loc, rewriter.getI32IntegerAttr(1));
1562  Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
1563  Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
1564 
1565  y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
1566  arith::CmpIPredicate::slt, rewriter);
1567  y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
1568  arith::CmpIPredicate::slt, rewriter);
1569 
1570  x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
1571  arith::CmpIPredicate::slt, rewriter);
1572  x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
1573  arith::CmpIPredicate::slt, rewriter);
1574 
1575  y0 =
1576  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
1577  y1 =
1578  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y1);
1579  x0 =
1580  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x0);
1581  x1 =
1582  rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x1);
1583 
1584  Value y0x0 = rewriter.create<tensor::ExtractOp>(
1585  loc, input, ValueRange{batch, y0, x0, channel});
1586  Value y0x1 = rewriter.create<tensor::ExtractOp>(
1587  loc, input, ValueRange{batch, y0, x1, channel});
1588  Value y1x0 = rewriter.create<tensor::ExtractOp>(
1589  loc, input, ValueRange{batch, y1, x0, channel});
1590  Value y1x1 = rewriter.create<tensor::ExtractOp>(
1591  loc, input, ValueRange{batch, y1, x1, channel});
1592 
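 // Sketch of the bilinear blend computed below (in both modes):
 //   topAcc    = y0x0 * (1 - dx) + y0x1 * dx
 //   bottomAcc = y1x0 * (1 - dx) + y1x1 * dx
 //   result    = topAcc * (1 - dy) + bottomAcc * dy
 // In the fixed-point path, 1.0 is represented as 1 << shift.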
1593  if (floatingPointMode) {
1594  auto oneVal = rewriter.create<arith::ConstantOp>(
1595  loc, rewriter.getF32FloatAttr(1.f));
1596  Value rightPart = dx;
1597  Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
1598 
1599  y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
1600  y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
1601  Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);
1602 
1603  y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
1604  y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
1605  Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
1606 
1607  Value bottomPart = dy;
1608  Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
1609  topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
1610  bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
1611  Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
1612 
1613  rewriter.create<linalg::YieldOp>(loc, result);
1614  return success();
1615  }
1616  y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
1617  y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
1618  y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
1619  y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
1620 
1621  if (resultElementTy.getIntOrFloatBitWidth() > 32) {
1622  dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
1623  dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
1624  }
1625 
1626  auto unitVal = rewriter.create<arith::ConstantOp>(
1627  loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
1628  Value rightPart = dx;
1629  Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx);
1630 
1631  y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
1632  y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
1633  Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
1634 
1635  y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
1636  y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
1637  Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
1638 
1639  Value bottomPart = dy;
1640  Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy);
1641  topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
1642  bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
1643  Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
1644 
1645  rewriter.create<linalg::YieldOp>(loc, result);
1646  return success();
1647  }
1648  return failure();
1649  }
1650 };
1651 
1652 // At the codegen level any identity operations should be removed. Any cases
1653 // where identity is load-bearing (e.g. cross device computation) should be
1654 // handled before lowering to codegen.
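// For example, %0 = "tosa.identity"(%arg0) is erased and all uses of %0 are
// replaced with %arg0 directly.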
1655 template <typename SrcOp>
1656 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1657 public:
1658  using OpRewritePattern<SrcOp>::OpRewritePattern;
1659 
1660  LogicalResult matchAndRewrite(SrcOp op,
1661  PatternRewriter &rewriter) const final {
1662  rewriter.replaceOp(op, op.getOperation()->getOperands());
1663  return success();
1664  }
1665 };
1666 
1667 template <typename SrcOp>
1668 class ReduceConverter : public OpRewritePattern<SrcOp> {
1669 public:
1670  using OpRewritePattern<SrcOp>::OpRewritePattern;
1671 
1672  LogicalResult matchAndRewrite(SrcOp reduceOp,
1673  PatternRewriter &rewriter) const final {
1674  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
1675  }
1676 };
1677 
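// Lowers tosa.concat to a zero-filled linalg.init_tensor followed by one
// tensor.insert_slice per operand, advancing the insertion offset along the
// concatenation axis by each operand's size.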
1678 struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
1679  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
1680 
1681  LogicalResult
1682  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
1683  ConversionPatternRewriter &rewriter) const override {
1684  auto resultType = op.getType().dyn_cast<RankedTensorType>();
1685  if (!resultType || !resultType.hasStaticShape()) {
1686  return rewriter.notifyMatchFailure(op,
1687  "expected static shaped tensor type");
1688  }
1689 
1690  Location loc = op.getLoc();
1691  int axis = op.axis();
1692  Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
1693  loc, rewriter.getIndexAttr(axis));
1694  int rank = resultType.getRank();
1695  SmallVector<Value, 3> offsets, sizes, strides;
1696  sizes.reserve(rank);
1697  strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
1698  offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
1699 
1700  for (int i = 0; i < rank; ++i) {
1701  sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
1702  loc, adaptor.getOperands()[0], i));
1703  }
1704 
1705  Value resultDimSize = sizes[axis];
1706  for (auto arg : adaptor.getOperands().drop_front()) {
1707  auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
1708  resultDimSize =
1709  rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
1710  }
1711  sizes[axis] = resultDimSize;
1712 
1713  Value init = rewriter.create<linalg::InitTensorOp>(
1714  loc, resultType.getShape(), resultType.getElementType());
1715 
1716  Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
1717  loc, rewriter.getZeroAttr(resultType.getElementType()));
1718  Value result =
1719  rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);
1720 
1721  auto toOpFoldResult = [](Value v) -> OpFoldResult {
1722  auto op = v.getDefiningOp<arith::ConstantIndexOp>();
1723  if (!op)
1724  return v;
1725  return op.getValue();
1726  };
1727  for (auto arg : adaptor.getOperands()) {
1728  sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
1729  result = rewriter.createOrFold<tensor::InsertSliceOp>(
1730  loc, arg, result,
1731  llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
1732  llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
1733  llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
1734  offsets[axis] =
1735  rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
1736  }
1737  rewriter.replaceOp(op, result);
1738  return success();
1739  }
1740 };
1741 
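// Lowers tosa.reverse to a linalg.generic that gathers each output element
// from the mirrored index (dimSize - 1 - i) along the reversed axis.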
1742 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1743 public:
1744  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1745 
1746  LogicalResult matchAndRewrite(tosa::ReverseOp op,
1747  PatternRewriter &rewriter) const final {
1748  auto loc = op.getLoc();
1749  Value input = op.input();
1750  auto inputTy = input.getType().template cast<ShapedType>();
1751  auto resultTy = op.getType().template cast<ShapedType>();
1752  auto axis = op.axis();
1753 
1754  SmallVector<Value> dynDims;
1755  for (int i = 0; i < inputTy.getRank(); i++) {
1756  if (inputTy.isDynamicDim(i)) {
1757  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1758  }
1759  }
1760 
1761  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1762 
1763  // First fill the output buffer with the init value.
1764  auto initTensor = rewriter
1765  .create<linalg::InitTensorOp>(
1766  loc, ArrayRef<Value>({dynDims}),
1767  inputTy.getShape(), inputTy.getElementType())
1768  .result();
1769  SmallVector<AffineMap, 2> affineMaps = {
1770  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1771 
1772  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1773  op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
1774  getNParallelLoopsAttrs(resultTy.getRank()),
1775  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1776  llvm::SmallVector<Value> indices;
1777  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1778  auto index =
1779  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1780  if (i == axis) {
1781  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1782  auto sizeMinusOne =
1783  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1784  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1785  index);
1786  }
1787 
1788  indices.push_back(index);
1789  }
1790 
1791  auto extract = nestedBuilder.create<tensor::ExtractOp>(
1792  nestedLoc, input, indices);
1793  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1794  extract.getResult());
1795  });
1796  return success();
1797  }
1798 };
1799 
1800 // This converter translates a tile operation to a reshape, broadcast,
1801 // reshape. The first reshape minimally expands each tiled dimension to
1802 // include a preceding size-1 dim. This dim is then broadcast to the
1803 // appropriate multiple.
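// Illustrative example (hypothetical shapes): tiling a tensor<3x4xf32> by
// multiples [2, 3] broadcasts through an intermediate tensor<2x3x3x4xf32>,
// which the trailing tosa.reshape collapses to the final tensor<6x12xf32>.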
1804 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1805  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1806 
1807  LogicalResult
1808  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1809  ConversionPatternRewriter &rewriter) const override {
1810  auto loc = op.getLoc();
1811  auto input = op.input1();
1812  auto inputTy = input.getType().cast<ShapedType>();
1813  auto inputShape = inputTy.getShape();
1814  auto resultTy = op.getType().cast<ShapedType>();
1815  auto elementTy = inputTy.getElementType();
1816  int64_t rank = inputTy.getRank();
1817 
1818  if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
1819  return failure();
1820 
1821  SmallVector<int64_t> multiples;
1822  getValuesFromIntArrayAttribute(op.multiples(), multiples);
1823 
1824  // Broadcast the newly added dimensions to their appropriate multiple.
1825  SmallVector<int64_t, 2> genericShape;
1826  for (int i = 0; i < rank; i++) {
1827  genericShape.push_back(multiples[i]);
1828  genericShape.push_back(inputShape[i]);
1829  }
1830 
1831  auto initTensor = rewriter.create<linalg::InitTensorOp>(
1832  op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
1833 
1834  // We need to map the input shape to the non-broadcasted dimensions.
1835  SmallVector<AffineExpr, 4> dimExprs;
1836  dimExprs.reserve(rank);
1837  for (unsigned i = 0; i < rank; ++i)
1838  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1839 
1840  auto readAffineMap =
1841  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1842  rewriter.getContext());
1843 
1844  SmallVector<AffineMap, 2> affineMaps = {
1845  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1846 
1847  auto genericOp = rewriter.create<linalg::GenericOp>(
1848  loc, RankedTensorType::get(genericShape, elementTy), input,
1849  ValueRange{initTensor}, affineMaps,
1850  getNParallelLoopsAttrs(genericShape.size()),
1851  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1852  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1853  });
1854 
1855  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1856  op, resultTy, genericOp.getResult(0),
1857  rewriter.getI64ArrayAttr(resultTy.getShape()));
1858  return success();
1859  }
1860 };
1861 
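// Lowers tosa.pad to a tensor pad operation (via tensor::createPadScalarOp),
// reading the per-dimension low/high padding amounts out of the static
// padding tensor, and taking the pad value from pad_const, a zero constant,
// or the quantization zero point.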
1862 class PadConverter : public OpRewritePattern<tosa::PadOp> {
1863 public:
1864  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
1865 
1866  LogicalResult matchAndRewrite(tosa::PadOp padOp,
1867  PatternRewriter &rewriter) const final {
1868  auto loc = padOp.getLoc();
1869  auto input = padOp.input1();
1870  auto padding = padOp.padding();
1871 
1872  ShapedType inputTy = input.getType().cast<ShapedType>();
1873  ShapedType paddingTy = padding.getType().cast<ShapedType>();
1874  Type elementTy = inputTy.getElementType();
1875  int64_t rank = inputTy.getRank();
1876 
1877  if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
1878  return rewriter.notifyMatchFailure(
1879  padOp,
1880  "Pad converter requires static shaped input / padding values.");
1881  }
1882 
1883  // Set up the default constantAttr.
1884 
1885  Value padConstant;
1886 
1887  if (padOp.pad_const()) {
1888  padConstant = rewriter.createOrFold<tensor::ExtractOp>(
1889  loc, padOp.pad_const(), ValueRange({}));
1890  } else {
1891  Attribute constantAttr;
1892  if (elementTy.isa<FloatType>())
1893  constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
1894  else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
1895  constantAttr = rewriter.getIntegerAttr(elementTy, 0);
1896  else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
1897  auto value = padOp.quantization_info().getValue().input_zp().getValue();
1898  constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
1899  }
1900  if (constantAttr)
1901  padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
1902  }
1903 
1904  if (!padConstant) {
1905  return rewriter.notifyMatchFailure(
1906  padOp, "tosa.pad was unable to determine the pad constant value.");
1907  }
1908 
1909  Value lowIndex =
1910  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
1911  Value highIndex =
1912  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
1913 
1914  SmallVector<OpFoldResult, 3> lowValues;
1915  SmallVector<OpFoldResult, 3> highValues;
1916 
1917  lowValues.reserve(rank);
1918  highValues.reserve(rank);
1919 
1920  for (int i = 0; i < rank; i++) {
1921  Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
1922  Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
1923  loc, padding, ValueRange({inputIndex, lowIndex}));
1924  Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
1925  loc, padding, ValueRange({inputIndex, highIndex}));
1926 
1927  lowVal = rewriter.createOrFold<arith::IndexCastOp>(
1928  loc, rewriter.getIndexType(), lowVal);
1929  highVal = rewriter.createOrFold<arith::IndexCastOp>(
1930  loc, rewriter.getIndexType(), highVal);
1931 
1932  lowValues.push_back(lowVal);
1933  highValues.push_back(highVal);
1934  }
1935 
1936  auto newPadOp = tensor::createPadScalarOp(
1937  padOp.getType(), input, padConstant, lowValues, highValues,
1938  /*nofold=*/false, loc, rewriter);
1939 
1940  rewriter.replaceOp(padOp, newPadOp.getResult());
1941  return success();
1942  }
1943 };
1944 
1945 // Tosa argmax lowering represents the ArgMax op as a linalg.generic op,
1946 // producing two output buffers.
1947 //
1948 // The first output buffer contains the index of the found maximum value. It is
1949 // initialized to 0 and is of the resulting integer type.
1950 //
1951 // The second output buffer contains the maximum value found. It is initialized
1952 // to the minimum representable value of the input element type. After being
1953 // populated by the generic op, this buffer is discarded as only the index is
1954 // requested.
1955 //
1956 // The generic op updates both the maximum value and index if the
1957 // current value exceeds the running max.
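// Illustrative example (hypothetical types): tosa.argmax over axis 1 of a
// tensor<3x5xf32> produces a tensor<3xi32> of indices; the auxiliary
// tensor<3xf32> of running maxima is dropped once the generic op completes.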
1958 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1959 public:
1960  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
1961 
1962  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1963  PatternRewriter &rewriter) const final {
1964  auto loc = argmaxOp.getLoc();
1965  Value input = argmaxOp.input();
1966  auto inputTy = input.getType().cast<ShapedType>();
1967  auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
1968  auto inElementTy = inputTy.getElementType();
1969  auto outElementTy = resultTy.getElementType();
1970  int axis = argmaxOp.axis();
1971  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1972 
1973  if (!inputTy.hasStaticShape())
1974  return rewriter.notifyMatchFailure(
1975  argmaxOp,
1976  "tosa.arg_max to linalg.* requires statically shaped input");
1977 
1978  if (!outElementTy.isa<IntegerType>())
1979  return rewriter.notifyMatchFailure(
1980  argmaxOp,
1981  "tosa.arg_max to linalg.* requires integer-like result type");
1982 
1983  // First fill the output buffer for the index.
1984  auto initTensorIdx =
1985  rewriter
1986  .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
1987  resultTy.getShape(), outElementTy)
1988  .result();
1989  auto fillValueIdx = rewriter.create<arith::ConstantOp>(
1990  loc, rewriter.getIntegerAttr(outElementTy, 0));
1991  auto filledTensorIdx =
1992  rewriter.create<linalg::FillOp>(loc, fillValueIdx, initTensorIdx)
1993  .result();
1994 
1995  // Second fill the output buffer for the running max.
1996  auto initTensorMax =
1997  rewriter
1998  .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
1999  resultTy.getShape(), inElementTy)
2000  .result();
2001  auto fillValueMaxAttr =
2002  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2003 
2004  if (!fillValueMaxAttr)
2005  return rewriter.notifyMatchFailure(
2006  argmaxOp, "unsupported tosa.argmax element type");
2007 
2008  auto fillValueMax =
2009  rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
2010  auto filledTensorMax =
2011  rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
2012  .result();
2013 
2014  // We need to reduce along the arg-max axis, with parallel operations along
2015  // the rest.
2016  SmallVector<StringRef, 4> iteratorTypes;
2017  iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
2018  iteratorTypes[axis] = getReductionIteratorTypeName();
2019 
2020  SmallVector<AffineExpr, 2> srcExprs;
2021  SmallVector<AffineExpr, 2> dstExprs;
2022  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2023  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2024  if (axis != i)
2025  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2026  }
2027 
2028  bool didEncounterError = false;
2029  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
2030  auto linalgOp = rewriter.create<linalg::GenericOp>(
2031  loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2032  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2033  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2034  ValueRange blockArgs) {
2035  auto newValue = blockArgs[0];
2036  auto oldIndex = blockArgs[1];
2037  auto oldValue = blockArgs[2];
2038 
2039  Value newIndex = rewriter.create<arith::IndexCastOp>(
2040  nestedLoc, oldIndex.getType(),
2041  rewriter.create<linalg::IndexOp>(loc, axis));
2042 
2043  Value predicate;
2044  if (inElementTy.isa<FloatType>()) {
2045  predicate = rewriter.create<arith::CmpFOp>(
2046  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2047  } else if (inElementTy.isa<IntegerType>()) {
2048  predicate = rewriter.create<arith::CmpIOp>(
2049  nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2050  } else {
2051  didEncounterError = true;
2052  return;
2053  }
2054 
2055  auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
2056  newValue, oldValue);
2057  auto resultIndex = rewriter.create<mlir::SelectOp>(
2058  nestedLoc, predicate, newIndex, oldIndex);
2059  nestedBuilder.create<linalg::YieldOp>(
2060  nestedLoc, ValueRange({resultIndex, resultMax}));
2061  });
2062 
2063  if (didEncounterError)
2064  return rewriter.notifyMatchFailure(
2065  argmaxOp, "unsupported tosa.argmax element type");
2066 
2067  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2068  return success();
2069  }
2070 };
2071 
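// Lowers tosa.gather to a linalg.generic whose body extracts
// input[batch, indices[batch, n], channel] for each output element.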
2072 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2073 public:
2074  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2075  LogicalResult
2076  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2077  ConversionPatternRewriter &rewriter) const final {
2078  auto input = adaptor.getOperands()[0];
2079  auto indices = adaptor.getOperands()[1];
2080 
2081  auto resultTy = op.getType().cast<ShapedType>();
2082 
2083  auto dynamicDimsOr =
2084  checkHasDynamicBatchDims(rewriter, op, {input, indices, op.output()});
2085  if (!dynamicDimsOr.hasValue())
2086  return failure();
2087  SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
2088 
2089  auto resultElementTy = resultTy.getElementType();
2090 
2091  auto loc = op.getLoc();
2092 
2093  auto initTensor =
2094  rewriter
2095  .create<linalg::InitTensorOp>(loc, dynamicDims, resultTy.getShape(),
2096  resultElementTy)
2097  .result();
2098 
2099  SmallVector<AffineMap, 2> affineMaps = {
2100  AffineMap::get(
2101  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2102  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2103  rewriter.getContext()),
2104  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2105 
2106  auto genericOp = rewriter.create<linalg::GenericOp>(
2107  loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2108  ValueRange{initTensor}, affineMaps,
2109  getNParallelLoopsAttrs(resultTy.getRank()),
2110  [&](OpBuilder &b, Location loc, ValueRange args) {
2111  auto indexValue = args[0];
2112  auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2113  Value index1 = rewriter.create<arith::IndexCastOp>(
2114  loc, rewriter.getIndexType(), indexValue);
2115  auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2116  Value extract = rewriter.create<tensor::ExtractOp>(
2117  loc, input, ValueRange{index0, index1, index2});
2118  rewriter.create<linalg::YieldOp>(loc, extract);
2119  });
2120  rewriter.replaceOp(op, genericOp.getResult(0));
2121  return success();
2122  }
2123 };
2124 
2125 // Lowers the TableOp to a series of gathers and numeric operations. This
2126 // includes interpolation between the high/low values. For the I8 variant,
2127 // this simplifies to a single gather operation.
2128 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2129 public:
2131 
2132  LogicalResult matchAndRewrite(tosa::TableOp op,
2133  PatternRewriter &rewriter) const final {
2134  auto loc = op.getLoc();
2135  Value input = op.input();
2136  Value table = op.table();
2137  auto inputTy = input.getType().cast<ShapedType>();
2138  auto tableTy = table.getType().cast<ShapedType>();
2139  auto resultTy = op.getType().cast<ShapedType>();
2140 
2141  if (!inputTy.hasStaticShape())
2142  return rewriter.notifyMatchFailure(
2143  op, "require input type to have static shape");
2144 
2145  auto inputElementTy = inputTy.getElementType();
2146  auto tableElementTy = tableTy.getElementType();
2147  auto resultElementTy = resultTy.getElementType();
2148 
2149  auto initTensor =
2150  rewriter
2151  .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
2152  resultTy.getShape(), resultElementTy)
2153  .result();
2154 
2155  SmallVector<AffineMap, 2> affineMaps = {
2156  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2157  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2158 
2159  auto genericOp = rewriter.create<linalg::GenericOp>(
2160  loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
2161  getNParallelLoopsAttrs(resultTy.getRank()));
2162  rewriter.replaceOp(op, genericOp.getResult(0));
2163 
2164  {
2165  OpBuilder::InsertionGuard regionGuard(rewriter);
2166  Block *block = rewriter.createBlock(
2167  &genericOp.region(), genericOp.region().end(),
2168  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2169 
2170  auto inputValue = block->getArgument(0);
2171  rewriter.setInsertionPointToStart(block);
2172  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2173  resultElementTy.isInteger(8)) {
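 // The i8 case is a single table lookup: the signed input in [-128, 127]
 // is biased by 128 into the table index range [0, 255].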
2174  Value index = rewriter.create<arith::IndexCastOp>(
2175  loc, rewriter.getIndexType(), inputValue);
2176  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2177  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2178  index, offset);
2179  Value extract =
2180  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2181  rewriter.create<linalg::YieldOp>(loc, extract);
2182  return success();
2183  }
2184 
2185  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2186  resultElementTy.isInteger(32)) {
2187  Value extend = rewriter.create<arith::ExtSIOp>(
2188  loc, rewriter.getI32Type(), inputValue);
2189 
2190  auto offset = rewriter.create<arith::ConstantOp>(
2191  loc, rewriter.getI32IntegerAttr(32768));
2192  auto seven = rewriter.create<arith::ConstantOp>(
2193  loc, rewriter.getI32IntegerAttr(7));
2194  auto one = rewriter.create<arith::ConstantOp>(
2195  loc, rewriter.getI32IntegerAttr(1));
2196  auto b1111111 = rewriter.create<arith::ConstantOp>(
2197  loc, rewriter.getI32IntegerAttr(127));
2198 
2199  // Compute the index and fractional part from the input value:
2200  // value = value + 32768
2201  // index = value >> 7;
2202  // fraction = 0b1111111 & value
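 // Illustrative example (hypothetical input): an i16 input of -16384
 // becomes value == 16384 after the bias, giving index == 16384 >> 7 == 128
 // and fraction == 16384 & 127 == 0.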
2203  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2204  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2205  Value fraction =
2206  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2207 
2208  // Extract the base and next values from the table.
2209  // base = (int32_t) table[index];
2210  // next = (int32_t) table[index + 1];
2211  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2212 
2213  index = rewriter.create<arith::IndexCastOp>(
2214  loc, rewriter.getIndexType(), index);
2215  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2216  loc, rewriter.getIndexType(), indexPlusOne);
2217 
2218  Value base =
2219  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2220  Value next = rewriter.create<tensor::ExtractOp>(
2221  loc, table, ValueRange{indexPlusOne});
2222 
2223  base =
2224  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2225  next =
2226  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2227 
2228  // Use the fractional part to interpolate between the input values:
2229  // result = (base << 7) + (next - base) * fraction
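 // Illustrative example (hypothetical table values): with base == 100,
 // next == 228, and fraction == 64 (half way), the result is
 // (100 << 7) + 128 * 64 == 20992, i.e. 164 in Q7 fixed point.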
2230  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2231  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2232  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2233  Value result =
2234  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2235 
2236  rewriter.create<linalg::YieldOp>(loc, result);
2237 
2238  return success();
2239  }
2240  }
2241 
2242  return rewriter.notifyMatchFailure(
2243  op, "unable to create body for tosa.table op");
2244  }
2245 };
2246 
2247 } // namespace
2248 
2249 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2250  RewritePatternSet *patterns) {
2251  patterns->add<
2252  // clang-format off
2253  PointwiseConverter<tosa::AddOp>,
2254  PointwiseConverter<tosa::SubOp>,
2255  PointwiseConverter<tosa::MulOp>,
2256  PointwiseConverter<tosa::DivOp>,
2257  PointwiseConverter<tosa::NegateOp>,
2258  PointwiseConverter<tosa::PowOp>,
2259  PointwiseConverter<tosa::ReciprocalOp>,
2260  PointwiseConverter<tosa::RsqrtOp>,
2261  PointwiseConverter<tosa::LogOp>,
2262  PointwiseConverter<tosa::ExpOp>,
2263  PointwiseConverter<tosa::AbsOp>,
2264  PointwiseConverter<tosa::TanhOp>,
2265  PointwiseConverter<tosa::BitwiseAndOp>,
2266  PointwiseConverter<tosa::BitwiseOrOp>,
2267  PointwiseConverter<tosa::BitwiseNotOp>,
2268  PointwiseConverter<tosa::BitwiseXorOp>,
2269  PointwiseConverter<tosa::LogicalAndOp>,
2270  PointwiseConverter<tosa::LogicalNotOp>,
2271  PointwiseConverter<tosa::LogicalOrOp>,
2272  PointwiseConverter<tosa::LogicalXorOp>,
2273  PointwiseConverter<tosa::CastOp>,
2274  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2275  PointwiseConverter<tosa::LogicalRightShiftOp>,
2276  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2277  PointwiseConverter<tosa::ClzOp>,
2278  PointwiseConverter<tosa::SelectOp>,
2279  PointwiseConverter<tosa::GreaterOp>,
2280  PointwiseConverter<tosa::GreaterEqualOp>,
2281  PointwiseConverter<tosa::EqualOp>,
2282  PointwiseConverter<tosa::MaximumOp>,
2283  PointwiseConverter<tosa::MinimumOp>,
2284  PointwiseConverter<tosa::CeilOp>,
2285  PointwiseConverter<tosa::FloorOp>,
2286  PointwiseConverter<tosa::ClampOp>,
2287  PointwiseConverter<tosa::ReluNOp>,
2288  PointwiseConverter<tosa::SigmoidOp>,
2289  IdentityNConverter<tosa::IdentityOp>,
2290  ReduceConverter<tosa::ReduceAllOp>,
2291  ReduceConverter<tosa::ReduceAnyOp>,
2292  ReduceConverter<tosa::ReduceMinOp>,
2293  ReduceConverter<tosa::ReduceMaxOp>,
2294  ReduceConverter<tosa::ReduceSumOp>,
2295  ReduceConverter<tosa::ReduceProdOp>,
2296  ArgMaxConverter,
2297  ConcatConverter,
2298  GatherConverter,
2299  PadConverter,
2300  ReshapeConverterCollapse,
2301  ReshapeConverterExpand,
2302  ReshapeConverterCollapseExpand,
2303  RescaleConverter,
2304  ResizeConverter,
2305  ReverseConverter,
2306  TableConverter,
2307  TileConverter,
2308  TransposeConverter>(patterns->getContext());
2309  // clang-format on
2310 }