//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

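// Creates the scalar arith/math computation implementing a TOSA elementwise
// op on the block arguments of a linalg.generic body, returning the scalar
// result, or nullptr if the op/element-type combination is unhandled. For
// example (illustrative IR), a tosa.add on f32 operands becomes:
//   %0 = arith.addf %arg0, %arg1 : f32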
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                              args[0], zero);
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
      !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    auto constant =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
      cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp = quantizationInfo.value().getInputZp();
    int64_t outZp = quantizationInfo.value().getOutputZp();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to
    // represent it. We assume 48-bit numbers may be supported further in the
    // pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //   outputValue = inZp + outZp - inputValue
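    // For example (illustrative), with inZp = 5 and outZp = 3, an input value
    // of 2 becomes (5 + 3) - 2 = 6 before clamping back to the input range.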
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    Value min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    Value max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp = clampIntHelper(loc, sub, min, max, rewriter);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

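    // Rounding adds one when the shift amount is positive and the last bit
    // shifted out of input1 is 1. For example (illustrative), -3 >> 1 yields
    // -2, and rounding adjusts it to -1.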
    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Check that the shift amount (input2) is greater than zero.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Check whether the last bit shifted out of input1 is 1.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    return clampFloatHelper(loc, args[0], min, max, rewriter);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int32_t min = static_cast<int32_t>(
        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
    int32_t max = static_cast<int32_t>(
        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());

    if (intTy.isUnsignedInteger()) {
      min = std::max<int32_t>(min, 0);
      max = std::min<int32_t>(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min = std::max<int32_t>(
          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
      max = std::min<int32_t>(
          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
    }

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
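  // Lowers tosa.cast by dispatching on the (source, destination) element
  // types: extension/truncation between floats or integers, int<->float
  // conversions, and a compare-against-zero for casts to i1. Float-to-int
  // casts round to nearest even and saturate to the destination range.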
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // When casting to boolean, floats only need to be checked as not equal
    // to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto intMin = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);

      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
    }

    // When casting to boolean, integers only need to be checked as not equal
    // to zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

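// Expands the rank of 'tensor' by prepending unit dimensions. For example
// (illustrative), expanding a tensor<2x3xf32> to rank 4 produces a
// tensor<1x1x2x3xf32> via tensor.expand_shape with reassociation indices
// [[0, 1, 2], [3]].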
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                        int64_t rank) {
  // No need to expand if we are already at the desired rank
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  int64_t numExtraDims = rank - shapedType.getRank();
  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
  if (!numExtraDims)
    return tensor;

  // Compute reassociation indices
  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
      shapedType.getRank());
  int64_t index = 0;
  for (index = 0; index <= numExtraDims; index++)
    reassociationIndices[0].push_back(index);
  for (size_t position = 1; position < reassociationIndices.size(); position++)
    reassociationIndices[position].push_back(index++);

  // Compute result type
  SmallVector<int64_t> resultShape;
  for (index = 0; index < numExtraDims; index++)
    resultShape.push_back(1);
  for (auto size : shapedType.getShape())
    resultShape.push_back(size);
  auto resultType =
      RankedTensorType::get(resultShape, shapedType.getElementType());

  // Emit 'tensor.expand_shape' op
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociationIndices);
}

static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
                                           Location loc, Operation *operation) {
  auto rank =
      operation->getResultTypes().front().cast<RankedTensorType>().getRank();
  return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
    return expandRank(rewriter, loc, operand, rank);
  });
}

using IndexPool = DenseMap<int64_t, Value>;

// Emit an 'arith.constant' op for the given index if it has not been created
// yet, or return an existing constant. This prevents excessive creation of
// redundant constants, easing readability of emitted code for unit tests.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}

// Compute the runtime dimension size for dimension 'dim' of the output by
// inspecting input 'operands', all of which are expected to have the same
// rank. This function returns a pair {targetSize, masterOperand}.
//
// The runtime size of the output dimension is returned either as a statically
// computed attribute or as a runtime SSA value.
//
// If the target size was inferred directly from one dominating operand, that
// operand is returned in 'masterOperand'. If the target size is inferred from
// multiple operands, 'masterOperand' is set to nullptr.
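// For example (illustrative), if the operand sizes for 'dim' are {1, ?, ?},
// the target size is emitted as the runtime maximum of the two dynamic sizes
// and 'masterOperand' is nullptr, since no single operand determines it.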
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any input operand contains a static size greater than 1 for this
  // dimension, that is the target size. An occurrence of an additional static
  // dimension greater than 1 with a different value is undefined behavior.
  for (auto operand : operands) {
    auto size = operand.getType().cast<RankedTensorType>().getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with dynamic dimension
  auto operandsWithDynamicDim =
      llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
        return operand.getType().cast<RankedTensorType>().isDynamicDim(dim);
      }));

  // If no operand has a dynamic dimension, it means all sizes were 1
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension. If there is
  // only one operand with a dynamic dimension, it is considered the master
  // operand that determines the runtime size of the output dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Calculate maximum size among all dynamic dimensions
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}

// Compute the runtime output size for all dimensions. This function returns
// a pair {targetShape, masterOperands}.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = operands.front().getType().cast<RankedTensorType>().getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
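
// Broadcasts a dynamic dimension of 'operand' to 'targetSize' if needed. The
// broadcast is guarded by an 'scf.if' that tests at runtime whether the
// dimension equals 1; if so, a 'linalg.generic' materializes the broadcast,
// otherwise the operand is yielded unchanged.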
672 
674  IndexPool &indexPool, Value operand,
675  int64_t dim, OpFoldResult targetSize,
676  Value masterOperand) {
677  // Nothing to do if this is a static dimension
678  auto rankedTensorType = operand.getType().cast<RankedTensorType>();
679  if (!rankedTensorType.isDynamicDim(dim))
680  return operand;
681 
682  // If the target size for this dimension was directly inferred by only taking
683  // this operand into account, there is no need to broadcast. This is an
684  // optimization that will prevent redundant control flow, and constitutes the
685  // main motivation for tracking "master operands".
686  if (operand == masterOperand)
687  return operand;
688 
689  // Affine maps for 'linalg.generic' op
690  auto rank = rankedTensorType.getRank();
691  SmallVector<AffineExpr> affineExprs;
692  for (auto index : llvm::seq<int64_t>(0, rank)) {
693  auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
694  : rewriter.getAffineDimExpr(index);
695  affineExprs.push_back(affineExpr);
696  }
697  auto broadcastAffineMap =
698  AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
699  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
700  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
701 
702  // Check if broadcast is necessary
703  auto one = createIndex(rewriter, loc, indexPool, 1);
704  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
705  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
706  loc, arith::CmpIPredicate::eq, runtimeSize, one);
707 
708  // Emit 'then' region of 'scf.if'
709  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
710  // Emit 'tensor.empty' op
711  SmallVector<OpFoldResult> outputTensorShape;
712  for (auto index : llvm::seq<int64_t>(0, rank)) {
713  auto size = index == dim ? targetSize
714  : getOrFoldTensorDim(rewriter, loc, indexPool,
715  operand, index);
716  outputTensorShape.push_back(size);
717  }
718  Value outputTensor = opBuilder.create<tensor::EmptyOp>(
719  loc, outputTensorShape, rankedTensorType.getElementType());
720 
721  // Emit 'linalg.generic' op
722  auto resultTensor =
723  opBuilder
724  .create<linalg::GenericOp>(
725  loc, outputTensor.getType(), operand, outputTensor, affineMaps,
727  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
728  // Emit 'linalg.yield' op
729  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
730  })
731  .getResult(0);
732 
733  // Cast to original operand type if necessary
734  auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
735  loc, operand.getType(), resultTensor);
736 
737  // Emit 'scf.yield' op
738  opBuilder.create<scf::YieldOp>(loc, castResultTensor);
739  };
740 
741  // Emit 'else' region of 'scf.if'
742  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
743  opBuilder.create<scf::YieldOp>(loc, operand);
744  };
745 
746  // Emit 'scf.if' op
747  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
748  emitThenRegion, emitElseRegion);
749  return ifOp.getResult(0);
750 }

static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  size_t rank = operand.getType().cast<RankedTensorType>().getRank();
  assert(targetShape.size() == rank);
  assert(masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations
  if (operands.size() == 1)
    return operands;

  // Broadcast dynamic dimensions operand by operand
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}

static LogicalResult
emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape) {
  // Generate output tensor
  auto resultType =
      operation->getResultTypes().front().cast<RankedTensorType>();
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of
  // size 1, while the output affine map is an identity map.
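  // For example (illustrative), an operand of shape (1, d1) receives the map
  // (d0, d1) -> (0, d1), while the output uses (d0, d1) -> (d0, d1).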
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
                                        : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {

  // Collect op properties
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto expandedOperands = expandInputRanks(rewriter, loc, operation);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                    targetShape);
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
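// For example, tosa.reduce_min over f32 starts from the largest positive
// float, so that no input element can compare larger.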
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
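// For example (illustrative), tosa.reduce_sum on f32 combines the current
// element with the accumulator through a single arith.addf.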
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value and the linalg.generic operation
// that reduces across the specified axis.
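// For example (illustrative), reducing a tensor<4x5xf32> along axis 1 emits a
// linalg.generic producing a tensor<4xf32>, which is then expanded back to
// the TOSA result shape tensor<4x1xf32> with tensor.expand_shape.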
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, int64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<utils::IteratorType, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? utils::IteratorType::reduction
                                      : utils::IteratorType::parallel);
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
  // not have access to such information. This matters when handling dynamically
  // sized tensors.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}

namespace {

template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};

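// Lowers tosa.transpose to a linalg.generic whose input indexing map applies
// the permutation. For example (illustrative), perms = [0, 2, 1] yields the
// input map (d0, d1, d2) -> (d0, d2, d1) against an identity output map.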
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
      return rewriter.notifyMatchFailure(op, "unmatched permutation tensor");
    }

    auto loc = op.getLoc();
    auto input = op->getOperand(0);
    auto resultTy = cast<ShapedType>(op.getType());

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    auto operandTy = cast<ShapedType>(input.getType());
    for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
      auto index = permutation.index();
      auto value = permutation.value().getZExtValue();
      if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
      }
      inputExprs[value] = rewriter.getAffineDimExpr(index);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims);

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });
    return success();
  }
};

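// Lowers tosa.rescale to a linalg.generic that, per element, subtracts the
// input zero point, applies the scale via tosa.apply_scale, adds the output
// zero point, and clamps to the range of the output type.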
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues(op.getMultiplier());
    SmallVector<int8_t> shiftValues(op.getShift());

    // Shifting by more than the bitwidth just sets the result to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if the shift is greater than 31; check
    // whether this is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the output tensor and the linalg.generic op that computes
    // the rescale.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // For now we widen the computation to 32 bits (48 bits for inputs
          // wider than 32 bits). This is not optimal, but should be correct;
          // consider computing the minimal required bit depth later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};

// Handle the resize case where the input is a 1x1 image. This case can be
// lowered entirely without tensor.extract operations, which are much more
// difficult to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    // TODO(suderman): These string values should be declared in the TOSA
    // dialect.
    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

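    // The TOSA resize scale attribute holds {scale_y_n, scale_y_d, scale_x_n,
    // scale_x_d}; only the numerators are needed for the 1x1 bilinear case
    // below.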
    ArrayRef<int64_t> scale = op.getScale();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting and scaling the input
    // value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};

// TOSA resize with a width or height of 1 may be broadcast to a wider
// dimension. This is done by materializing a new tosa.resize without the
// broadcasting behavior, followed by an explicit broadcast.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    // Collapse any unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape{batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
1507 
1508 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1509 public:
1510  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1511 
1512  LogicalResult matchAndRewrite(tosa::ResizeOp op,
1513  PatternRewriter &rewriter) const final {
1514  Location loc = op.getLoc();
1515  ImplicitLocOpBuilder b(loc, rewriter);
1516  auto input = op.getInput();
1517  auto inputTy = cast<ShapedType>(input.getType());
1518  auto resultTy = cast<ShapedType>(op.getType());
1519  auto resultETy = resultTy.getElementType();
1520 
1521  auto imageH = inputTy.getShape()[1];
1522  auto imageW = inputTy.getShape()[2];
1523 
1524  auto dynamicDimsOr =
1525  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1526  if (!dynamicDimsOr.has_value())
1527  return rewriter.notifyMatchFailure(
1528  op, "unable to get dynamic dimensions of tosa.resize");
1529 
1530  if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1531  return rewriter.notifyMatchFailure(
1532  op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1533 
1534  SmallVector<AffineMap, 2> affineMaps = {
1535  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1536  auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1537  *dynamicDimsOr);
1538  auto genericOp = b.create<linalg::GenericOp>(
1539  resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1540  getNParallelLoopsAttrs(resultTy.getRank()));
1541  Value resize = genericOp.getResult(0);
1542 
1543  {
1544  OpBuilder::InsertionGuard regionGuard(b);
1545  b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1546  TypeRange({resultETy}), loc);
1547  Value batch = b.create<linalg::IndexOp>(0);
1548  Value y = b.create<linalg::IndexOp>(1);
1549  Value x = b.create<linalg::IndexOp>(2);
1550  Value channel = b.create<linalg::IndexOp>(3);
1551 
1552  Value zeroI32 =
1553  b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1554  Value zeroFp32 =
1555  b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
1556  Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1557  Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1558 
1559  Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1560  Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1561 
1562  bool floatingPointMode = resultETy.isF32();
1563 
1564  ArrayRef<int64_t> offset = op.getOffset();
1565  ArrayRef<int64_t> border = op.getBorder();
1566  ArrayRef<int64_t> scale = op.getScale();
1567 
1568  Value yScaleN, yScaleD, xScaleN, xScaleD;
1569  yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1570  yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1571  xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1572  xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1573 
1574  Value yOffset, xOffset, yBorder, xBorder;
1575  yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1576  xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1577  yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1578  xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1579 
1580  // Compute the ix and dx values for both the X and Y dimensions.
1581  auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1582  Value scaleN, Value scaleD, Value offset,
1583  int size, ImplicitLocOpBuilder &b) {
1584  if (size == 1) {
1585  index = zeroI32;
1586  delta = zeroFp32;
1587  return;
1588  }
1589  // x = x * scale_d + offset;
1590  // ix = floor(x / scale_n)
1591  // dx = x / scale_n - ix
1592  Value val = b.create<arith::UIToFPOp>(b.getF32Type(), in);
1593  scaleN = b.create<arith::UIToFPOp>(b.getF32Type(), scaleN);
1594  scaleD = b.create<arith::UIToFPOp>(b.getF32Type(), scaleD);
1595  offset = b.create<arith::SIToFPOp>(b.getF32Type(), offset);
1596  val = b.create<arith::MulFOp>(val, scaleD);
1597  val = b.create<arith::AddFOp>(val, offset);
1598  val = b.create<arith::DivFOp>(val, scaleN);
1599  index = b.create<math::FloorOp>(val);
1600  delta = b.create<arith::SubFOp>(val, index);
1601  index = b.create<arith::FPToSIOp>(b.getI32Type(), index);
1602  };
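  // Worked example (values chosen purely for illustration): with
  // scale_n = 4, scale_d = 2, offset = 0 and in = 3, we get
  // x = 3 * 2 + 0 = 6, ix = floor(6 / 4) = 1 and dx = 6.0 / 4.0 - 1.0 = 0.5.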
1603 
1604  // Compute the ix and dx values for the X and Y dimensions - int case.
1605  auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1606  Value scaleN, Value scaleD, Value offset,
1607  int size, ImplicitLocOpBuilder &b) {
1608  if (size == 1) {
1609  index = zeroI32;
1610  delta = zeroI32;
1611  return;
1612  }
1613  // x = x * scale_d + offset;
1614  // ix = floor(x / scale_n)
1615  // dx = x - ix * scale_n;
1616  Value val = b.create<arith::MulIOp>(in, scaleD);
1617  val = b.create<arith::AddIOp>(val, offset);
1618  index = b.create<arith::DivSIOp>(val, scaleN);
1619  delta = b.create<arith::MulIOp>(index, scaleN);
1620  delta = b.create<arith::SubIOp>(val, delta);
1621  };
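  // Worked example (same illustrative values): x = 3 * 2 + 0 = 6,
  // ix = 6 / 4 = 1 and dx = 6 - 1 * 4 = 2, i.e. the fractional part is
  // kept scaled by scale_n rather than converted to a float.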
1622 
1623  Value ix, iy, dx, dy;
1624  if (floatingPointMode) {
1625  getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1626  getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1627  } else {
1628  getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1629  getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1630  }
1631 
1632  if (op.getMode() == "NEAREST_NEIGHBOR") {
1633  auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1634 
1635  auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1636  Value max, int size,
1637  ImplicitLocOpBuilder &b) -> Value {
1638  if (size == 1) {
1639  return b.create<arith::ConstantIndexOp>(0);
1640  }
1641 
1642  Value pred;
1643  if (floatingPointMode) {
1644  auto h = b.create<arith::ConstantOp>(b.getF32FloatAttr(0.5f));
1645  pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1646  } else {
1647  Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1648  pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1649  dvalDouble, scale);
1650  }
1651 
1652  auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1653  val = b.create<arith::AddIOp>(val, offset);
1654  val = clampIntHelper(loc, val, zeroI32, max, b);
1655  return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1656  };
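  // In floating-point mode this rounds up when the fractional part is at
  // least 0.5; the integer branch tests 2 * dx >= scale_n instead, since
  // dx is kept scaled by scale_n there.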
1657 
1658  iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1659  ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1660 
1661  Value result = b.create<tensor::ExtractOp>(
1662  input, ValueRange{batch, iy, ix, channel});
1663 
1664  b.create<linalg::YieldOp>(result);
1665  } else {
1666  // The mode here must be BILINEAR.
1667  assert(op.getMode() == "BILINEAR");
1668 
1669  auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1670 
1671  auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1672  Value max, ImplicitLocOpBuilder &b) {
1673  val0 = in;
1674  val1 = b.create<arith::AddIOp>(val0, oneVal);
1675  val0 = clampIntHelper(loc, val0, zeroI32, max, b);
1676  val1 = clampIntHelper(loc, val1, zeroI32, max, b);
1677  val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1678  val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1679  };
1680 
1681  // Linalg equivalent to the section below:
1682  // int16_t iy0 = apply_max(iy, 0);
1683  // int16_t iy1 = apply_min(iy + 1, IH - 1);
1684  // int16_t ix0 = apply_max(ix, 0);
1685  // int16_t ix1 = apply_min(ix + 1, IW - 1);
1686  Value x0, x1, y0, y1;
1687  getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1688  getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1689 
1690  Value y0x0 = b.create<tensor::ExtractOp>(
1691  input, ValueRange{batch, y0, x0, channel});
1692  Value y0x1 = b.create<tensor::ExtractOp>(
1693  input, ValueRange{batch, y0, x1, channel});
1694  Value y1x0 = b.create<tensor::ExtractOp>(
1695  input, ValueRange{batch, y1, x0, channel});
1696  Value y1x1 = b.create<tensor::ExtractOp>(
1697  input, ValueRange{batch, y1, x1, channel});
1698 
1699  if (floatingPointMode) {
1700  auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
1701  auto interpolate = [&](Value val0, Value val1, Value delta,
1702  int inputSize,
1703  ImplicitLocOpBuilder &b) -> Value {
1704  if (inputSize == 1)
1705  return val0;
1706  Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1707  Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1708  Value mul1 = b.create<arith::MulFOp>(val1, delta);
1709  return b.create<arith::AddFOp>(mul0, mul1);
1710  };
1711 
1712  // Linalg equivalent to the section below:
1713  // topAcc = v00 * (unit_x - dx);
1714  // topAcc += v01 * dx;
1715  Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1716 
1717  // Linalg equivalent to the section below:
1718  // bottomAcc = v10 * (unit_x - dx);
1719  // bottomAcc += v11 * dx;
1720  Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1721 
1722  // Linalg equivalent to the section below:
1723  // result = topAcc * (unit_y - dy) + bottomAcc * dy
1724  Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1725  b.create<linalg::YieldOp>(result);
1726  } else {
1727  // Perform in quantized space.
1728  y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1729  y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1730  y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1731  y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1732 
1733  const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1734  if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1735  dx = b.create<arith::ExtSIOp>(resultETy, dx);
1736  dy = b.create<arith::ExtSIOp>(resultETy, dy);
1737  }
1738 
1739  Value yScaleNExt = yScaleN;
1740  Value xScaleNExt = xScaleN;
1741 
1742  const int64_t scaleBitwidth =
1743  xScaleN.getType().getIntOrFloatBitWidth();
1744  if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1745  yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1746  xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1747  }
1748 
1749  auto interpolate = [](Value val0, Value val1, Value weight1,
1750  Value scale, int inputSize,
1751  ImplicitLocOpBuilder &b) -> Value {
1752  if (inputSize == 1)
1753  return b.create<arith::MulIOp>(val0, scale);
1754  Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1755  Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1756  Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1757  return b.create<arith::AddIOp>(mul0, mul1);
1758  };
1759 
1760  Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1761  Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1762  Value result =
1763  interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1764  b.create<linalg::YieldOp>(result);
1765  }
1766  }
1767  }
1768 
1769  rewriter.replaceOp(op, resize);
1770  return success();
1771  }
1772 };
1773 
1774 // At the codegen level any identity operations should be removed. Any cases
1775 // where identity is load-bearing (e.g. cross-device computation) should be
1776 // handled before lowering to codegen.
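// Schematically, %0 = tosa.identity %arg0 : (tensor<4xf32>) -> tensor<4xf32>
// is erased by replacing all uses of %0 with %arg0 directly.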
1777 template <typename SrcOp>
1778 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1779 public:
1780  using OpRewritePattern<SrcOp>::OpRewritePattern;
1781 
1782  LogicalResult matchAndRewrite(SrcOp op,
1783  PatternRewriter &rewriter) const final {
1784  rewriter.replaceOp(op, op.getOperation()->getOperands());
1785  return success();
1786  }
1787 };
1788 
1789 template <typename SrcOp>
1790 class ReduceConverter : public OpRewritePattern<SrcOp> {
1791 public:
1792  using OpRewritePattern<SrcOp>::OpRewritePattern;
1793 
1794  LogicalResult matchAndRewrite(SrcOp reduceOp,
1795  PatternRewriter &rewriter) const final {
1796  return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1797  }
1798 };
1799 
1800 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1801 public:
1802  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1803 
1804  LogicalResult matchAndRewrite(tosa::ReverseOp op,
1805  PatternRewriter &rewriter) const final {
1806  auto loc = op.getLoc();
1807  Value input = op.getInput();
1808  auto inputTy = cast<ShapedType>(input.getType());
1809  auto resultTy = cast<ShapedType>(op.getType());
1810  auto axis = op.getAxis();
1811 
1812  SmallVector<Value> dynDims;
1813  for (int i = 0; i < inputTy.getRank(); i++) {
1814  if (inputTy.isDynamicDim(i)) {
1815  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1816  }
1817  }
1818 
1819  Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1820 
1821  // First fill the output buffer with the init value.
1822  auto emptyTensor = rewriter
1823  .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1824  inputTy.getElementType(),
1825  ArrayRef<Value>({dynDims}))
1826  .getResult();
1827  SmallVector<AffineMap, 2> affineMaps = {
1828  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1829 
1830  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1831  op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
1832  getNParallelLoopsAttrs(resultTy.getRank()),
1833  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1834  llvm::SmallVector<Value> indices;
1835  for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1836  Value index =
1837  rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1838  if (i == axis) {
1839  auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1840  auto sizeMinusOne =
1841  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1842  index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1843  index);
1844  }
1845 
1846  indices.push_back(index);
1847  }
1848 
1849  auto extract = nestedBuilder.create<tensor::ExtractOp>(
1850  nestedLoc, input, indices);
1851  nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1852  extract.getResult());
1853  });
1854  return success();
1855  }
1856 };
1857 
1858 // This converter translates a tile operation into a reshape, broadcast, and
1859 // reshape. The first reshape minimally expands each tiled dimension to
1860 // include a preceding size-1 dim. This dim is then broadcast to the
1861 // appropriate multiple.
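// Schematically (shapes chosen for illustration): tiling tensor<2x3xf32>
// with multiples [2, 1] runs a linalg.generic over the shape <2x2x1x3>,
// where the input map reads only the odd dimensions, and the result is
// then reshaped by tosa.reshape to the final tensor<4x3xf32>.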
1862 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1863  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1864 
1865  LogicalResult
1866  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1867  ConversionPatternRewriter &rewriter) const override {
1868  auto loc = op.getLoc();
1869  auto input = op.getInput1();
1870  auto inputTy = cast<ShapedType>(input.getType());
1871  auto inputShape = inputTy.getShape();
1872  auto resultTy = cast<ShapedType>(op.getType());
1873  auto elementTy = inputTy.getElementType();
1874  int64_t rank = inputTy.getRank();
1875 
1876  ArrayRef<int64_t> multiples = op.getMultiples();
1877 
1878  // Broadcast the newly added dimensions to their appropriate multiple.
1879  SmallVector<int64_t, 2> genericShape;
1880  for (int i = 0; i < rank; i++) {
1881  int64_t dim = multiples[i];
1882  genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1883  genericShape.push_back(inputShape[i]);
1884  }
1885 
1886  SmallVector<Value> dynDims;
1887  for (int i = 0; i < inputTy.getRank(); i++) {
1888  if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1889  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1890  }
1891  }
1892 
1893  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1894  op.getLoc(), genericShape, elementTy, dynDims);
1895 
1896  // We need to map the input shape to the non-broadcast dimensions.
1897  SmallVector<AffineExpr, 4> dimExprs;
1898  dimExprs.reserve(rank);
1899  for (unsigned i = 0; i < rank; ++i)
1900  dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1901 
1902  auto readAffineMap =
1903  AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1904  rewriter.getContext());
1905 
1906  SmallVector<AffineMap, 2> affineMaps = {
1907  readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1908 
1909  auto genericOp = rewriter.create<linalg::GenericOp>(
1910  loc, RankedTensorType::get(genericShape, elementTy), input,
1911  ValueRange{emptyTensor}, affineMaps,
1912  getNParallelLoopsAttrs(genericShape.size()),
1913  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1914  nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1915  });
1916 
1917  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1918  op, resultTy, genericOp.getResult(0),
1919  rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
1920  return success();
1921  }
1922 };
1923 
1924 // Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
1925 // op, producing two output buffers.
1926 //
1927 // The first output buffer contains the index of the found maximum value. It is
1928 // initialized to 0 and is of the resulting integer type.
1929 //
1930 // The second output buffer contains the maximum value found. It is initialized
1931 // to the minimum representable value of the input element type. After being
1932 // populated by indexed_generic, this buffer is discarded as only the index is
1933 // requested.
1934 //
1935 // The indexed_generic op updates both the maximum value and index if the
1936 // current value exceeds the running max.
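// Schematically, the region body computes, per element:
//   newIndex = index(axis)
//   pred = newValue > oldValue
//   yield (pred ? newIndex : oldIndex, pred ? newValue : oldValue)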
1937 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1938 public:
1939  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
1940 
1941  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1942  PatternRewriter &rewriter) const final {
1943  auto loc = argmaxOp.getLoc();
1944  Value input = argmaxOp.getInput();
1945  auto inputTy = cast<ShapedType>(input.getType());
1946  auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
1947  auto inElementTy = inputTy.getElementType();
1948  auto outElementTy = resultTy.getElementType();
1949  int axis = argmaxOp.getAxis();
1950  auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1951 
1952  if (!isa<IntegerType>(outElementTy))
1953  return rewriter.notifyMatchFailure(
1954  argmaxOp,
1955  "tosa.arg_max to linalg.* requires integer-like result type");
1956 
1957  SmallVector<Value> dynDims;
1958  for (int i = 0; i < inputTy.getRank(); i++) {
1959  if (inputTy.isDynamicDim(i) && i != axis) {
1960  dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1961  }
1962  }
1963 
1964  // First fill the output buffer for the index.
1965  auto emptyTensorIdx = rewriter
1966  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1967  outElementTy, dynDims)
1968  .getResult();
1969  auto fillValueIdx = rewriter.create<arith::ConstantOp>(
1970  loc, rewriter.getIntegerAttr(outElementTy, 0));
1971  auto filledTensorIdx =
1972  rewriter
1973  .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
1974  ValueRange{emptyTensorIdx})
1975  .result();
1976 
1977  // Second fill the output buffer for the running max.
1978  auto emptyTensorMax = rewriter
1979  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1980  inElementTy, dynDims)
1981  .getResult();
1982  auto fillValueMaxAttr =
1983  createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
1984 
1985  if (!fillValueMaxAttr)
1986  return rewriter.notifyMatchFailure(
1987  argmaxOp, "unsupported tosa.argmax element type");
1988 
1989  auto fillValueMax =
1990  rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
1991  auto filledTensorMax =
1992  rewriter
1993  .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
1994  ValueRange{emptyTensorMax})
1995  .result();
1996 
1997  // We need to reduce along the arg-max axis, with parallel operations along
1998  // the rest.
1999  SmallVector<utils::IteratorType> iteratorTypes;
2000  iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2001  iteratorTypes[axis] = utils::IteratorType::reduction;
2002 
2003  SmallVector<AffineExpr, 2> srcExprs;
2004  SmallVector<AffineExpr, 2> dstExprs;
2005  for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2006  srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2007  if (axis != i)
2008  dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2009  }
2010 
2011  bool didEncounterError = false;
2012  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
2013  auto linalgOp = rewriter.create<linalg::GenericOp>(
2014  loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2015  ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2016  [&](OpBuilder &nestedBuilder, Location nestedLoc,
2017  ValueRange blockArgs) {
2018  auto newValue = blockArgs[0];
2019  auto oldIndex = blockArgs[1];
2020  auto oldValue = blockArgs[2];
2021 
2022  Value newIndex = rewriter.create<arith::IndexCastOp>(
2023  nestedLoc, oldIndex.getType(),
2024  rewriter.create<linalg::IndexOp>(loc, axis));
2025 
2026  Value predicate;
2027  if (isa<FloatType>(inElementTy)) {
2028  predicate = rewriter.create<arith::CmpFOp>(
2029  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2030  } else if (isa<IntegerType>(inElementTy)) {
2031  predicate = rewriter.create<arith::CmpIOp>(
2032  nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2033  } else {
2034  didEncounterError = true;
2035  return;
2036  }
2037 
2038  auto resultMax = rewriter.create<arith::SelectOp>(
2039  nestedLoc, predicate, newValue, oldValue);
2040  auto resultIndex = rewriter.create<arith::SelectOp>(
2041  nestedLoc, predicate, newIndex, oldIndex);
2042  nestedBuilder.create<linalg::YieldOp>(
2043  nestedLoc, ValueRange({resultIndex, resultMax}));
2044  });
2045 
2046  if (didEncounterError)
2047  return rewriter.notifyMatchFailure(
2048  argmaxOp, "unsupported tosa.argmax element type");
2049 
2050  rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2051  return success();
2052  }
2053 };
2054 
2055 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2056 public:
2057  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2058  LogicalResult
2059  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2060  ConversionPatternRewriter &rewriter) const final {
2061  auto input = adaptor.getOperands()[0];
2062  auto indices = adaptor.getOperands()[1];
2063 
2064  auto valuesTy =
2065  dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2066  auto resultTy = cast<ShapedType>(op.getType());
2067 
2068  if (!valuesTy)
2069  return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2070 
2071  auto dynamicDims = inferDynamicDimsForGather(
2072  rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2073 
2074  auto resultElementTy = resultTy.getElementType();
2075 
2076  auto loc = op.getLoc();
2077  auto emptyTensor =
2078  rewriter
2079  .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2080  dynamicDims)
2081  .getResult();
2082 
2083  SmallVector<AffineMap, 2> affineMaps = {
2084  AffineMap::get(
2085  /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2086  {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2087  rewriter.getContext()),
2088  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2089 
2090  auto genericOp = rewriter.create<linalg::GenericOp>(
2091  loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2092  ValueRange{emptyTensor}, affineMaps,
2093  getNParallelLoopsAttrs(resultTy.getRank()),
2094  [&](OpBuilder &b, Location loc, ValueRange args) {
2095  auto indexValue = args[0];
2096  auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2097  Value index1 = rewriter.create<arith::IndexCastOp>(
2098  loc, rewriter.getIndexType(), indexValue);
2099  auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2100  Value extract = rewriter.create<tensor::ExtractOp>(
2101  loc, input, ValueRange{index0, index1, index2});
2102  rewriter.create<linalg::YieldOp>(loc, extract);
2103  });
2104  rewriter.replaceOp(op, genericOp.getResult(0));
2105  return success();
2106  }
2107 
2108  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2109  Location loc,
2110  Value values,
2111  Value indices) {
2112  llvm::SmallVector<Value> results;
2113 
2114  auto addDynamicDimension = [&](Value source, int64_t dim) {
2115  auto sz = tensor::getMixedSize(builder, loc, source, dim);
2116  if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2117  results.push_back(dimValue);
2118  };
2119 
2120  addDynamicDimension(values, 0);
2121  addDynamicDimension(indices, 1);
2122  addDynamicDimension(values, 2);
2123  return results;
2124  }
2125 };
2126 
2127 // Lowers the TableOp to a series of gathers and numeric operations. This
2128 // includes interpolation between the high/low values. For the I8 variant, this
2129 // simplifies to a single gather operation.
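// Schematically, the 16-bit variant computes, per element:
//   index = (in + 32768) >> 7
//   fraction = (in + 32768) & 0x7F
//   result = (table[index] << 7) + (table[index + 1] - table[index]) * fraction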
2130 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2131 public:
2132  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2133 
2134  LogicalResult matchAndRewrite(tosa::TableOp op,
2135  PatternRewriter &rewriter) const final {
2136  auto loc = op.getLoc();
2137  Value input = op.getInput();
2138  Value table = op.getTable();
2139  auto inputTy = cast<ShapedType>(input.getType());
2140  auto tableTy = cast<ShapedType>(table.getType());
2141  auto resultTy = cast<ShapedType>(op.getType());
2142 
2143  auto inputElementTy = inputTy.getElementType();
2144  auto tableElementTy = tableTy.getElementType();
2145  auto resultElementTy = resultTy.getElementType();
2146 
2147  SmallVector<Value> dynDims;
2148  for (int i = 0; i < resultTy.getRank(); ++i) {
2149  if (inputTy.isDynamicDim(i)) {
2150  dynDims.push_back(
2151  rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2152  }
2153  }
2154 
2155  auto emptyTensor = rewriter
2156  .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2157  resultElementTy, dynDims)
2158  .getResult();
2159 
2160  SmallVector<AffineMap, 2> affineMaps = {
2161  rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2162  rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2163 
2164  auto genericOp = rewriter.create<linalg::GenericOp>(
2165  loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2166  getNParallelLoopsAttrs(resultTy.getRank()));
2167  rewriter.replaceOp(op, genericOp.getResult(0));
2168 
2169  {
2170  OpBuilder::InsertionGuard regionGuard(rewriter);
2171  Block *block = rewriter.createBlock(
2172  &genericOp.getRegion(), genericOp.getRegion().end(),
2173  TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2174 
2175  auto inputValue = block->getArgument(0);
2176  rewriter.setInsertionPointToStart(block);
2177  if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2178  resultElementTy.isInteger(8)) {
2179  Value index = rewriter.create<arith::IndexCastOp>(
2180  loc, rewriter.getIndexType(), inputValue);
2181  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2182  index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2183  index, offset);
2184  Value extract =
2185  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2186  rewriter.create<linalg::YieldOp>(loc, extract);
2187  return success();
2188  }
2189 
2190  if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2191  resultElementTy.isInteger(32)) {
2192  Value extend = rewriter.create<arith::ExtSIOp>(
2193  loc, rewriter.getI32Type(), inputValue);
2194 
2195  auto offset = rewriter.create<arith::ConstantOp>(
2196  loc, rewriter.getI32IntegerAttr(32768));
2197  auto seven = rewriter.create<arith::ConstantOp>(
2198  loc, rewriter.getI32IntegerAttr(7));
2199  auto one = rewriter.create<arith::ConstantOp>(
2200  loc, rewriter.getI32IntegerAttr(1));
2201  auto b1111111 = rewriter.create<arith::ConstantOp>(
2202  loc, rewriter.getI32IntegerAttr(127));
2203 
2204  // Compute the index and fractional part from the input value:
2205  // value = value + 32768
2206  // index = value >> 7;
2207  // fraction = value & 0x7F
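  // Worked example (illustrative): in = 0 gives value = 32768, so
  // index = 32768 >> 7 = 256 and fraction = 32768 & 0x7F = 0, making the
  // result exactly table[256] << 7.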
2208  auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2209  Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2210  Value fraction =
2211  rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2212 
2213  // Extract the base and next values from the table.
2214  // base = (int32_t) table[index];
2215  // next = (int32_t) table[index + 1];
2216  Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2217 
2218  index = rewriter.create<arith::IndexCastOp>(
2219  loc, rewriter.getIndexType(), index);
2220  indexPlusOne = rewriter.create<arith::IndexCastOp>(
2221  loc, rewriter.getIndexType(), indexPlusOne);
2222 
2223  Value base =
2224  rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2225  Value next = rewriter.create<tensor::ExtractOp>(
2226  loc, table, ValueRange{indexPlusOne});
2227 
2228  base =
2229  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2230  next =
2231  rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2232 
2233  // Use the fractional part to interpolate between the input values:
2234  // result = (base << 7) + (next - base) * fraction
2235  Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2236  Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2237  Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2238  Value result =
2239  rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2240 
2241  rewriter.create<linalg::YieldOp>(loc, result);
2242 
2243  return success();
2244  }
2245  }
2246 
2247  return rewriter.notifyMatchFailure(
2248  op, "unable to create body for tosa.table op");
2249  }
2250 };
2251 
2252 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2253  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2254 
2255  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2256 
2257  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2258  OpFoldResult ofr) {
2259  auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
2260  auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
2261 
2262  auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2263  auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2264  auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2265  return getAsOpFoldResult(plusOne);
2266  }
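  // For example, W = 8 yields 8 / 2 + 1 = 5 output columns, matching the
  // Hermitian symmetry of the FFT of a real-valued input.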
2267 
2268  static RankedTensorType
2269  computeOutputShape(OpBuilder &builder, Location loc, Value input,
2270  llvm::SmallVectorImpl<Value> &dynamicSizes) {
2271  // Get [N, H, W]
2272  auto dims = tensor::getMixedSizes(builder, loc, input);
2273 
2274  // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2275  // output tensors.
2276  dims[2] = halfPlusOne(builder, loc, dims[2]);
2277 
2278  llvm::SmallVector<int64_t, 3> staticSizes;
2279  dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2280 
2281  auto elementType =
2282  input.getType().cast<RankedTensorType>().getElementType();
2283  return RankedTensorType::get(staticSizes, elementType);
2284  }
2285 
2286  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2287  RankedTensorType type,
2288  llvm::ArrayRef<Value> dynamicSizes) {
2289  auto emptyTensor =
2290  rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2291  auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2292  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2293  auto filledTensor = rewriter
2294  .create<linalg::FillOp>(loc, ValueRange{fillValue},
2295  ValueRange{emptyTensor})
2296  .result();
2297  return filledTensor;
2298  }
2299 
2300  static Value castIndexToFloat(OpBuilder &builder, Location loc,
2301  FloatType type, Value value) {
2302  auto integerVal =
2303  builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
2304 
2305  return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2306  }
2307 
2308  static Value createLinalgIndex(OpBuilder &builder, Location loc,
2309  FloatType type, int64_t index) {
2310  auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2311  return castIndexToFloat(builder, loc, type, indexVal);
2312  }
2313 
2314  template <typename... Args>
2315  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2316  Args... args) {
2317  return {builder.getAffineDimExpr(args)...};
2318  }
2319 
2320  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2321  PatternRewriter &rewriter) const override {
2322  if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2323  !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2324  return rewriter.notifyMatchFailure(rfft2d,
2325  "only supports ranked tensors");
2326  }
2327 
2328  auto loc = rfft2d.getLoc();
2329  auto input = rfft2d.getInput();
2330  auto elementType =
2331  input.getType().cast<ShapedType>().getElementType().cast<FloatType>();
2332 
2333  // Compute the output type and set of dynamic sizes
2334  llvm::SmallVector<Value> dynamicSizes;
2335  auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2336 
2337  // Iterator types for the linalg.generic implementation
2338  SmallVector<utils::IteratorType, 5> iteratorTypes = {
2339  utils::IteratorType::parallel, utils::IteratorType::parallel,
2340  utils::IteratorType::parallel, utils::IteratorType::reduction,
2341  utils::IteratorType::reduction};
2342 
2343  // Inputs/outputs to the linalg.generic implementation
2344  llvm::SmallVector<Value> genericOpInputs = {input};
2345  llvm::SmallVector<Value> genericOpOutputs = {
2346  createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2347  createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2348 
2349  // Indexing maps for input and output tensors
2350  auto indexingMaps = AffineMap::inferFromExprList(llvm::ArrayRef{
2351  affineDimsExpr(rewriter, 0, 3, 4), affineDimsExpr(rewriter, 0, 1, 2),
2352  affineDimsExpr(rewriter, 0, 1, 2)});
2353 
2354  // Width and height dimensions of the original input.
2355  auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2356  auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2357 
2358  // Constants and dimension sizes
2359  auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2360  auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2361  auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2362  auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2363 
2364  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2365  Value valReal = args[0];
2366  Value sumReal = args[1];
2367  Value sumImag = args[2];
2368 
2369  // Indices for angle computation
2370  auto oy = createLinalgIndex(builder, loc, elementType, 1);
2371  auto ox = createLinalgIndex(builder, loc, elementType, 2);
2372  auto iy = createLinalgIndex(builder, loc, elementType, 3);
2373  auto ix = createLinalgIndex(builder, loc, elementType, 4);
2374 
2375  // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W)
2376  auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
2377  auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
2378  auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
2379  auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
2380  auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2381  auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2382 
2383  // realComponent = valReal * cos(angle)
2384  // imagComponent = valReal * sin(angle)
2385  auto cosAngle = builder.create<math::CosOp>(loc, angle);
2386  auto sinAngle = builder.create<math::SinOp>(loc, angle);
2387  auto realComponent =
2388  builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2389  auto imagComponent =
2390  builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2391 
2392  // outReal = sumReal + realComponent
2393  // outImag = sumImag - imagComponent
2394  auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2395  auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2396 
2397  builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2398  };
2399 
2400  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2401  rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2402  indexingMaps, iteratorTypes, buildBody);
2403 
2404  return success();
2405  }
2406 };
2407 
2408 } // namespace
2409 
2410 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2411  RewritePatternSet *patterns) {
2412 
2413  // We have multiple resize converters to handle degenerate cases.
2414  patterns->add<GenericResizeConverter>(patterns->getContext(),
2415  /*benefit=*/100);
2416  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2417  /*benefit=*/200);
2418  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2419  /*benefit=*/300);
2420 
2421  patterns->add<
2422  // clang-format off
2423  PointwiseConverter<tosa::AddOp>,
2424  PointwiseConverter<tosa::SubOp>,
2425  PointwiseConverter<tosa::MulOp>,
2426  PointwiseConverter<tosa::DivOp>,
2427  PointwiseConverter<tosa::NegateOp>,
2428  PointwiseConverter<tosa::PowOp>,
2429  PointwiseConverter<tosa::ReciprocalOp>,
2430  PointwiseConverter<tosa::RsqrtOp>,
2431  PointwiseConverter<tosa::LogOp>,
2432  PointwiseConverter<tosa::ExpOp>,
2433  PointwiseConverter<tosa::AbsOp>,
2434  PointwiseConverter<tosa::TanhOp>,
2435  PointwiseConverter<tosa::ErfOp>,
2436  PointwiseConverter<tosa::BitwiseAndOp>,
2437  PointwiseConverter<tosa::BitwiseOrOp>,
2438  PointwiseConverter<tosa::BitwiseNotOp>,
2439  PointwiseConverter<tosa::BitwiseXorOp>,
2440  PointwiseConverter<tosa::LogicalAndOp>,
2441  PointwiseConverter<tosa::LogicalNotOp>,
2442  PointwiseConverter<tosa::LogicalOrOp>,
2443  PointwiseConverter<tosa::LogicalXorOp>,
2444  PointwiseConverter<tosa::CastOp>,
2445  PointwiseConverter<tosa::LogicalLeftShiftOp>,
2446  PointwiseConverter<tosa::LogicalRightShiftOp>,
2447  PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2448  PointwiseConverter<tosa::ClzOp>,
2449  PointwiseConverter<tosa::SelectOp>,
2450  PointwiseConverter<tosa::GreaterOp>,
2451  PointwiseConverter<tosa::GreaterEqualOp>,
2452  PointwiseConverter<tosa::EqualOp>,
2453  PointwiseConverter<tosa::MaximumOp>,
2454  PointwiseConverter<tosa::MinimumOp>,
2455  PointwiseConverter<tosa::CeilOp>,
2456  PointwiseConverter<tosa::FloorOp>,
2457  PointwiseConverter<tosa::ClampOp>,
2458  PointwiseConverter<tosa::SigmoidOp>,
2459  IdentityNConverter<tosa::IdentityOp>,
2460  ReduceConverter<tosa::ReduceAllOp>,
2461  ReduceConverter<tosa::ReduceAnyOp>,
2462  ReduceConverter<tosa::ReduceMinOp>,
2463  ReduceConverter<tosa::ReduceMaxOp>,
2464  ReduceConverter<tosa::ReduceSumOp>,
2465  ReduceConverter<tosa::ReduceProdOp>,
2466  ArgMaxConverter,
2467  GatherConverter,
2468  RescaleConverter,
2469  ReverseConverter,
2470  RFFT2dConverter,
2471  TableConverter,
2472  TileConverter,
2473  TransposeConverter>(patterns->getContext());
2474  // clang-format on
2475 }