MLIR 22.0.0git
TosaToLinalg.cpp
1//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/Sequence.h"
31
32#include <type_traits>
33
34using namespace mlir;
35using namespace mlir::tosa;
36
37// Helper function to materialize the semantically correct compare and select
38// operations given a binary operation with a specific NaN propagation mode.
39//
40// In the case of "PROPAGATE" semantics no compare and selection is required and
41// this function does nothing.
42//
43// In the case that the op is operating on non-floating-point types we ignore
44// the attribute completely; this is consistent with the TOSA spec, which has
45// the following wording: "This attribute is ignored by non floating-point
46// types."
47//
48// In the case of "IGNORE" semantics this function materializes a comparison of
49// the current operands to the op which will return true for any NaN
50// argument and then selects between the non-NaN operation argument and the
51// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
52// code:
53//
54// binary<op>(lhs, rhs):
55// result = op(lhs, rhs)
56// if lhs == NaN return rhs
57// if rhs == NaN return lhs
58// return result
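//
// For example, a floating-point tosa.maximum with nan_mode = "IGNORE" is
// expected to lower to roughly the following sequence inside the linalg body
// (value names here are illustrative only):
//   %max  = arith.maximumf %lhs, %rhs
//   %lnan = arith.cmpf uno, %lhs, %lhs
//   %rnan = arith.cmpf uno, %rhs, %rhs
//   %tmp  = arith.select %lnan, %rhs, %max
//   %res  = arith.select %rnan, %lhs, %tmp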
59template <typename OpTy>
60static Value
61materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
62                                    Value lhs, Value rhs, Value result) {
63 // NaN propagation has no meaning for non floating point types.
64 if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
65 return result;
66
67 auto nanMode = op.getNanMode();
68 if (nanMode == NanPropagationMode::PROPAGATE)
69 return result;
70
71 // Unordered comparison of NaN against itself will always return true.
72 Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
73 arith::CmpFPredicate::UNO, lhs, lhs);
74 Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
75 arith::CmpFPredicate::UNO, rhs, rhs);
76 Value rhsOrResult =
77 arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
78 return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
79 rhsOrResult);
80}
81
82static Value createLinalgBodyCalculationForElementwiseOp(
83    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
84 ConversionPatternRewriter &rewriter) {
85 Location loc = op->getLoc();
86 auto elementTy =
87 cast<ShapedType>(op->getOperand(0).getType()).getElementType();
88
89 // tosa::AbsOp
90 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
91 return math::AbsFOp::create(rewriter, loc, resultTypes, args);
92
93 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
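    // Integer abs is computed as max(x, 0 - x).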
94 auto zero = arith::ConstantOp::create(rewriter, loc,
95 rewriter.getZeroAttr(elementTy));
96 auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
97 return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
98 }
99
100 // tosa::AddOp
101 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
102 return arith::AddFOp::create(rewriter, loc, resultTypes, args);
103
104 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
105 return arith::AddIOp::create(rewriter, loc, resultTypes, args);
106
107 // tosa::SubOp
108 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
109 return arith::SubFOp::create(rewriter, loc, resultTypes, args);
110
111 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
112 return arith::SubIOp::create(rewriter, loc, resultTypes, args);
113
114 // tosa::IntDivOp
115 if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
116 return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
117
118 // tosa::ReciprocalOp
119 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
120 auto one =
121 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
122 return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
123 }
124
125 // tosa::MulOp
126 if (isa<tosa::MulOp>(op)) {
127 auto shiftVal = cast<tosa::MulOp>(op).getShift();
128 DenseElementsAttr shiftElem;
129 bool shiftIsConstant = true;
130 int32_t shift = 0;
131 if (matchPattern(shiftVal, m_Constant(&shiftElem)))
132 shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
133 else
134 shiftIsConstant = false;
135
136 if (isa<FloatType>(elementTy)) {
137 if (shift != 0) {
138 (void)rewriter.notifyMatchFailure(op,
139 "Cannot have shift value for float");
140 return nullptr;
141 }
142 return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
143 args[1]);
144 }
145
146 if (isa<IntegerType>(elementTy)) {
147 Value a = args[0];
148 Value b = args[1];
149
150 if (shift > 0 || !shiftIsConstant) {
151 Value shiftConst;
152 if (shiftIsConstant)
153 shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
154 /*bitwidth=*/8);
155
156 if (!a.getType().isInteger(32))
157 a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
158
159 if (!b.getType().isInteger(32))
160 b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
161
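        // Per the TOSA spec, tosa.apply_scale computes approximately
        // (a * b) >> shift with the selected rounding mode, using a wider
        // internal accumulator so the i32 multiplication cannot overflow.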
162 auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
163 auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
164 RoundingMode::SINGLE_ROUND);
165 auto result =
166 tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
167 b, shiftAmount, roundingAttr);
168
169 return result;
170 }
171
172 int aWidth = a.getType().getIntOrFloatBitWidth();
173 int bWidth = b.getType().getIntOrFloatBitWidth();
174 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
175
176 if (aWidth < cWidth)
177 a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
178 if (bWidth < cWidth)
179 b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);
180
181 return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
182 }
183 }
184
185 // tosa::NegateOp
186 if (isa<tosa::NegateOp>(op)) {
187 auto negate = cast<tosa::NegateOp>(op);
188
189 int64_t inZp = 0, outZp = 0;
190 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
191 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
192 bool hasInZp = !failed(maybeInZp);
193 bool hasOutZp = !failed(maybeOutZp);
194 if (hasInZp)
195 inZp = *maybeInZp;
196 if (hasOutZp)
197 outZp = *maybeOutZp;
198
199 if (isa<FloatType>(elementTy))
200 return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
201
202 if (isa<IntegerType>(elementTy)) {
203 Value zpAddValue;
204 Type intermediateType;
205 // Compute the maximum value that can occur in the intermediate buffer.
206 const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
207 int intermediateBitWidth = 64;
208
209 if (hasInZp && hasOutZp) {
210 // Compute the maximum value that can occur in the intermediate buffer.
211 const int64_t zpAdd = inZp + outZp;
212 const int64_t maxValue =
213 APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
214 std::abs(zpAdd) + 1;
215
216 // Convert that maximum value into the maximum bitwidth needed to
217 // represent it.
218 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
219 intermediateBitWidth = 16;
220 } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
221 intermediateBitWidth = 32;
222 }
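        // For example, an i8 negate with input_zp = -128 and output_zp = 127
        // has zpAdd = -1 and maxValue = 127 + 1 + 1 = 129, so a 16-bit
        // intermediate type suffices.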
223
224 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
225 zpAddValue = arith::ConstantOp::create(
226 rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
227 } else {
228 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
229 auto arg1 =
230 arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
231 auto arg2 =
232 arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
233 zpAddValue =
234 arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
235 }
236
237 // The negation can be applied by doing:
238 // outputValue = inZp + outZp - inputValue
239 auto ext =
240 arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
241 auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
242
243 // Clamp to the negation range.
244      auto min = arith::ConstantIntOp::create(
245          rewriter, loc, intermediateType,
246          APInt::getSignedMinValue(inputBitWidth).getSExtValue());
247      auto max = arith::ConstantIntOp::create(
248          rewriter, loc, intermediateType,
249          APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
250 auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
251
252 // Truncate to the final value.
253 return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
254 }
255 }
256
257 // tosa::BitwiseAndOp
258 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
259 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
260
261 // tosa::BitwiseOrOp
262 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
263 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
264
265 // tosa::BitwiseNotOp
266 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
267 auto allOnesAttr = rewriter.getIntegerAttr(
268 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
269 auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
270 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
271 }
272
273 // tosa::BitwiseXOrOp
274 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
275 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
276
277 // tosa::LogicalLeftShiftOp
278 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
279 return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
280
281 // tosa::LogicalRightShiftOp
282 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
283 return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
284
285 // tosa::ArithmeticRightShiftOp
286 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
287 auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
288 auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
289 if (!round) {
290 return result;
291 }
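    // When rounding is requested, the result is adjusted by the last bit
    // shifted out of input1, guarded against a shift amount of zero; roughly:
    //   result += (input2 > 0) && ((input1 >> (input2 - 1)) & 1)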
292
293 Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
294 auto one = arith::ConstantOp::create(rewriter, loc,
295 IntegerAttr::get(elementTy, 1));
296 auto zero = arith::ConstantOp::create(rewriter, loc,
297 IntegerAttr::get(elementTy, 0));
298 auto i1zero =
299 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
300 auto i1one =
301 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
302
303 // Checking that input2 != 0
304 auto shiftValueGreaterThanZero = arith::CmpIOp::create(
305 rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
306
307 // Checking for the last bit of input1 to be 1
308 auto subtract =
309 arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
310 auto shifted =
311 arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
312 ->getResults();
313 auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
315 auto isInputOdd =
316 arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
317 // shifted, truncated, isInputOdd can be poison when input2 is 0.
318 auto shouldRound = arith::SelectOp::create(
319 rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
320 auto extended =
321 arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
322 return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
323 }
324
325 // tosa::ClzOp
326 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
327 return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
328 }
329
330 // tosa::LogicalAnd
331 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
332 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
333
334 // tosa::LogicalNot
335 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
336 auto one = arith::ConstantOp::create(rewriter, loc,
337 rewriter.getIntegerAttr(elementTy, 1));
338 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
339 }
340
341 // tosa::LogicalOr
342 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
343 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
344
345 // tosa::LogicalXor
346 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
347 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
348
349 // tosa::PowOp
350 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
351 return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
352
353 // tosa::RsqrtOp
354 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
355 return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
356
357 // tosa::LogOp
358 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
359 return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
360
361 // tosa::ExpOp
362 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
363 return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
364
365 // tosa::SinOp
366 if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
367 return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
368
369 // tosa::CosOp
370 if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
371 return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
372
373 // tosa::TanhOp
374 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
375 return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
376
377 // tosa::ErfOp
378 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
379 return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
380
381 // tosa::GreaterOp
382 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
383 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
384 args[0], args[1]);
385
386 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
387 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
388 args[0], args[1]);
389
390 // tosa::GreaterEqualOp
391 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
392 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
393 args[0], args[1]);
394
395 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
396 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
397 args[0], args[1]);
398
399 // tosa::EqualOp
400 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
401 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
402 args[0], args[1]);
403
404 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
405 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
406 args[0], args[1]);
407
408 // tosa::SelectOp
409 if (isa<tosa::SelectOp>(op)) {
410 elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
411 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
412 return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
413 }
414
415 // tosa::MaximumOp
416 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
417 auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
418 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
419 rewriter, args[0], args[1], max);
420 }
421
422 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
423 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
424 }
425
426 // tosa::MinimumOp
427 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
428 auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
429 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
430 rewriter, args[0], args[1], min);
431 }
432
433 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
434 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
435 }
436
437 // tosa::CeilOp
438 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
439 return math::CeilOp::create(rewriter, loc, resultTypes, args);
440
441 // tosa::FloorOp
442 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
443 return math::FloorOp::create(rewriter, loc, resultTypes, args);
444
445 // tosa::ClampOp
446 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
447 bool losesInfo = false;
448 APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
449 APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
450 minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
451 APFloat::rmNearestTiesToEven, &losesInfo);
452 maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
453 APFloat::rmNearestTiesToEven, &losesInfo);
454 auto min = arith::ConstantOp::create(
455 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
456 auto max = arith::ConstantOp::create(
457 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
458 auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
459
460 auto clampOp = llvm::cast<tosa::ClampOp>(op);
461 const auto nanMode = clampOp.getNanMode();
462
463 // NaN propagation has no meaning for non floating point types.
464 if (!isa<FloatType>(elementTy))
465 return result;
466
467 // In the case of "PROPAGATE" semantics no compare and selection is
468 // required.
469 if (nanMode == NanPropagationMode::PROPAGATE)
470 return result;
471
472 // In the case of "IGNORE" semantics materialize a comparison
473 // of the current operand to the reduction which will return true for a NaN
474 // argument and then selects between the initial reduction value and the
475 // calculated result based on whether the argument is NaN or not. In pseudo
476 // code:
477 //
478 // reduce<op>(x, init):
479 // result = op(init, x)
480 // return init if x == NaN else result
481
482 // Unordered comparison of NaN against itself will always return true.
483 Value isNaN = arith::CmpFOp::create(
484 rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
485 // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
486 // is NaN.
487 return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
488 }
489
490 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
491 auto intTy = cast<IntegerType>(elementTy);
492 int64_t min =
493 cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
494 int64_t max =
495 cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
496
497 int64_t minRepresentable = std::numeric_limits<int64_t>::min();
498 int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
499 if (intTy.isUnsignedInteger()) {
500 minRepresentable = 0;
501 if (intTy.getIntOrFloatBitWidth() <= 63) {
502 maxRepresentable =
503 (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
504 .getZExtValue();
505 }
506 } else if (intTy.getIntOrFloatBitWidth() <= 64) {
507 // Ensure that min & max fit into signed n-bit constants.
508 minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
509 .getSExtValue();
510 maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
511 .getSExtValue();
512 }
513 // Ensure that the bounds are representable as n-bit signed/unsigned
514 // integers.
515 min = std::max(min, minRepresentable);
516 max = std::max(max, minRepresentable);
517 min = std::min(min, maxRepresentable);
518 max = std::min(max, maxRepresentable);
519
520 auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
521 intTy.getIntOrFloatBitWidth());
522 auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
523 intTy.getIntOrFloatBitWidth());
524 return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
525 intTy.isUnsignedInteger());
526 }
527
528 // tosa::SigmoidOp
529 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
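    // sigmoid(x) = 1 / (1 + exp(-x)), built directly from arith/math ops.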
530 auto one =
531 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
532 auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
533 auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
534 auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
535 return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
536 }
537
538 // tosa::CastOp
539 if (isa<tosa::CastOp>(op)) {
540 Type srcTy = elementTy;
541 Type dstTy = resultTypes.front();
542 if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
543 (void)rewriter.notifyMatchFailure(op, "unsupported type");
544 return nullptr;
545 }
546
547    bool bitExtend =
548        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
549
550 if (srcTy == dstTy)
551 return args.front();
552
553 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
554 return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
556
557 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
558 return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
560
561 // 1-bit integers need to be treated as signless.
562 if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
563 return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
565
566 if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
567 return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
569
570 // Unsigned integers need an unrealized cast so that they can be passed
571 // to UIToFP.
572 if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
573 auto unrealizedCast =
574 UnrealizedConversionCastOp::create(
575 rewriter, loc,
576 rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
577 .getResult(0);
578 return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
579 unrealizedCast);
580 }
581
582 // All other si-to-fp conversions should be handled by SIToFP.
583 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
584 return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
586
587 // Casting to boolean, floats need to only be checked as not-equal to zero.
588 if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
589 Value zero = arith::ConstantOp::create(rewriter, loc,
590 rewriter.getFloatAttr(srcTy, 0.0));
591 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
592 args.front(), zero);
593 }
594
595 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
596 auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
597
598 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
599 // Check whether neither int min nor int max can be represented in the
600 // input floating-point type due to too short exponent range.
601 if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
602 APFloat::semanticsMaxExponent(fltSemantics)) {
603 // Use cmp + select to replace infinites by int min / int max. Other
604 // integral values can be represented in the integer space.
605 auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
606 auto posInf = arith::ConstantOp::create(
607 rewriter, loc,
608 rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
609 APFloat::getInf(fltSemantics)));
610 auto negInf = arith::ConstantOp::create(
611 rewriter, loc,
612 rewriter.getFloatAttr(
613                getElementTypeOrSelf(srcTy),
614                APFloat::getInf(fltSemantics, /*Negative=*/true)));
615 auto overflow = arith::CmpFOp::create(
616 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
617 auto underflow = arith::CmpFOp::create(
618 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
619 auto intMin = arith::ConstantOp::create(
620 rewriter, loc,
621 rewriter.getIntegerAttr(
622                getElementTypeOrSelf(dstTy),
623                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
624 auto intMax = arith::ConstantOp::create(
625 rewriter, loc,
626 rewriter.getIntegerAttr(
627                getElementTypeOrSelf(dstTy),
628                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
629 auto maxClamped =
630 arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
631 return arith::SelectOp::create(rewriter, loc, underflow, intMin,
632 maxClamped);
633 }
634
635 auto intMinFP = arith::ConstantOp::create(
636 rewriter, loc,
637 rewriter.getFloatAttr(
638              getElementTypeOrSelf(srcTy),
639              APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
640 .getSExtValue()));
641
642 // Check whether the mantissa has enough bits to represent int max.
643 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
644 dstTy.getIntOrFloatBitWidth() - 1) {
645 // Int min can also be represented since it is a power of two and thus
646 // consists of a single leading bit. Therefore we can clamp the input
647 // in the floating-point domain.
648
649 auto intMaxFP = arith::ConstantOp::create(
650 rewriter, loc,
651 rewriter.getFloatAttr(
652                getElementTypeOrSelf(srcTy),
653                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
654 .getSExtValue()));
655
656 Value clamped =
657 clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
658 return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
659 }
660
661      // Due to the earlier check we know the exponent range is big enough to represent
662 // int min. We can therefore rely on int max + 1 being representable as
663 // well because it's just int min with a positive sign. So clamp the min
664 // value and compare against that to select the max int value if needed.
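      // For example, for f32 -> i32: int max (2147483647) is not exactly
      // representable in f32, but int max + 1 = 2^31 is, so comparing the
      // rounded value against 2147483648.0 cleanly detects the overflow case.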
665 auto intMaxPlusOneFP = arith::ConstantOp::create(
666 rewriter, loc,
667 rewriter.getFloatAttr(
668              getElementTypeOrSelf(srcTy),
669              static_cast<double>(
670 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
671 .getSExtValue()) +
672 1.0f));
673
674 auto intMax = arith::ConstantOp::create(
675 rewriter, loc,
676 rewriter.getIntegerAttr(
677              getElementTypeOrSelf(dstTy),
678              APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
679 auto minClampedFP =
680 arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
681 auto minClamped =
682 arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
683 auto overflow = arith::CmpFOp::create(
684 rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
685 return arith::SelectOp::create(rewriter, loc, overflow, intMax,
686 minClamped);
687 }
688
689 // Casting to boolean, integers need to only be checked as not-equal to
690 // zero.
691 if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
692 Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
693 srcTy.getIntOrFloatBitWidth());
694 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
695 args.front(), zero);
696 }
697
698 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
699 return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
701
702 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
703 return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
704 }
705 }
706
707 (void)rewriter.notifyMatchFailure(
708 op, "unhandled op for linalg body calculation for elementwise op");
709 return nullptr;
710}
711
712using IndexPool = DenseMap<int64_t, Value>;
713
714// Emit an 'arith.constant' op for the given index if it has not been created
715// yet, or return an existing constant. This prevents excessive creation of
716// redundant constants, easing readability of the emitted code in unit tests.
717static Value createIndex(PatternRewriter &rewriter, Location loc,
718                         IndexPool &indexPool, int64_t index) {
719 auto [it, inserted] = indexPool.try_emplace(index);
720 if (inserted)
721 it->second =
722 arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
723 return it->second;
724}
725
726static Value getTensorDim(PatternRewriter &rewriter, Location loc,
727                          IndexPool &indexPool, Value tensor, int64_t index) {
728 auto indexValue = createIndex(rewriter, loc, indexPool, index);
729 return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
730}
731
732static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
733                                        IndexPool &indexPool, Value tensor,
734 int64_t index) {
735 auto shapedType = dyn_cast<ShapedType>(tensor.getType());
736 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
737 assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
738 if (shapedType.isDynamicDim(index))
739 return getTensorDim(rewriter, loc, indexPool, tensor, index);
740 return rewriter.getIndexAttr(shapedType.getDimSize(index));
741}
742
743static bool operandsAndResultsRanked(Operation *operation) {
744 auto isRanked = [](Value value) {
745 return isa<RankedTensorType>(value.getType());
746 };
747 return llvm::all_of(operation->getOperands(), isRanked) &&
748 llvm::all_of(operation->getResults(), isRanked);
749}
750
751// Compute the runtime dimension size for dimension 'dim' of the output by
752// inspecting input 'operands', all of which are expected to have the same rank.
753// This function returns a pair {targetSize, masterOperand}.
754//
755// The runtime size of the output dimension is returned either as a statically
756// computed attribute or as a runtime SSA value.
757//
758// If the target size was inferred directly from one dominating operand, that
759// operand is returned in 'masterOperand'. If the target size is inferred from
760// multiple operands, 'masterOperand' is set to nullptr.
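//
// For example, given operand dim sizes {1, ?, 5} and {3, ?, 1}, dimension 0
// resolves statically to 3, dimension 2 to 5, and dimension 1 is computed at
// runtime as the maximum of the two dynamic sizes (with no master operand).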
761static std::pair<OpFoldResult, Value>
762computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
763                  ValueRange operands, int64_t dim) {
764 // If any input operand contains a static size greater than 1 for this
765 // dimension, that is the target size. An occurrence of an additional static
766 // dimension greater than 1 with a different value is undefined behavior.
767 for (auto operand : operands) {
768 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
769 if (ShapedType::isStatic(size) && size > 1)
770 return {rewriter.getIndexAttr(size), operand};
771 }
772
773 // Filter operands with dynamic dimension
774 auto operandsWithDynamicDim =
775 llvm::filter_to_vector(operands, [&](Value operand) {
776 return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
777 });
778
779 // If no operand has a dynamic dimension, it means all sizes were 1
780 if (operandsWithDynamicDim.empty())
781 return {rewriter.getIndexAttr(1), operands.front()};
782
783 // Emit code that computes the runtime size for this dimension. If there is
784 // only one operand with a dynamic dimension, it is considered the master
785 // operand that determines the runtime size of the output dimension.
786 auto targetSize =
787 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
788 if (operandsWithDynamicDim.size() == 1)
789 return {targetSize, operandsWithDynamicDim[0]};
790
791 // Calculate maximum size among all dynamic dimensions
792 for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
793 auto nextSize =
794 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
795 targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
796 }
797 return {targetSize, nullptr};
798}
799
800// Compute the runtime output size for all dimensions. This function returns
801// a pair {targetShape, masterOperands}.
802static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
803computeTargetShape(PatternRewriter &rewriter, Location loc,
804                   IndexPool &indexPool, ValueRange operands) {
805 assert(!operands.empty());
806 auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
807 SmallVector<OpFoldResult> targetShape;
808 SmallVector<Value> masterOperands;
809 for (auto dim : llvm::seq<int64_t>(0, rank)) {
810 auto [targetSize, masterOperand] =
811 computeTargetSize(rewriter, loc, indexPool, operands, dim);
812 targetShape.push_back(targetSize);
813 masterOperands.push_back(masterOperand);
814 }
815 return {targetShape, masterOperands};
816}
817
818static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
819                                       IndexPool &indexPool, Value operand,
820 int64_t dim, OpFoldResult targetSize,
821 Value masterOperand) {
822 // Nothing to do if this is a static dimension
823 auto rankedTensorType = cast<RankedTensorType>(operand.getType());
824 if (!rankedTensorType.isDynamicDim(dim))
825 return operand;
826
827 // If the target size for this dimension was directly inferred by only taking
828 // this operand into account, there is no need to broadcast. This is an
829 // optimization that will prevent redundant control flow, and constitutes the
830 // main motivation for tracking "master operands".
831 if (operand == masterOperand)
832 return operand;
833
834 // Affine maps for 'linalg.generic' op
835 auto rank = rankedTensorType.getRank();
836 SmallVector<AffineExpr> affineExprs;
837 for (auto index : llvm::seq<int64_t>(0, rank)) {
838 auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
839 : rewriter.getAffineDimExpr(index);
840 affineExprs.push_back(affineExpr);
841 }
842 auto broadcastAffineMap =
843 AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
844 auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
845 SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
846
847 // Check if broadcast is necessary
848 auto one = createIndex(rewriter, loc, indexPool, 1);
849 auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
850 auto broadcastNecessary = arith::CmpIOp::create(
851 rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
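  // i.e. a broadcast along this dimension is required exactly when this
  // operand's runtime size is 1 (the target size may then be larger).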
852
853 // Emit 'then' region of 'scf.if'
854 auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
855 // It is not safe to cache constants across regions.
856 // New constants could potentially violate dominance requirements.
857 IndexPool localPool;
858
859 // Emit 'tensor.empty' op
860 SmallVector<OpFoldResult> outputTensorShape;
861 for (auto index : llvm::seq<int64_t>(0, rank)) {
862 auto size = index == dim ? targetSize
863 : getOrFoldTensorDim(rewriter, loc, localPool,
864 operand, index);
865 outputTensorShape.push_back(size);
866 }
867 Value outputTensor = tensor::EmptyOp::create(
868 opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
869
870 // Emit 'linalg.generic' op
871 auto resultTensor =
872 linalg::GenericOp::create(
873 opBuilder, loc, outputTensor.getType(), operand, outputTensor,
874 affineMaps, getNParallelLoopsAttrs(rank),
875 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
876 // Emit 'linalg.yield' op
877 linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
878 })
879 .getResult(0);
880
881 // Cast to original operand type if necessary
882 auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
883 loc, operand.getType(), resultTensor);
884
885 // Emit 'scf.yield' op
886 scf::YieldOp::create(opBuilder, loc, castResultTensor);
887 };
888
889 // Emit 'else' region of 'scf.if'
890 auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
891 scf::YieldOp::create(opBuilder, loc, operand);
892 };
893
894 // Emit 'scf.if' op
895 auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
896 emitThenRegion, emitElseRegion);
897 return ifOp.getResult(0);
898}
899
900static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
901                                         IndexPool &indexPool, Value operand,
902 ArrayRef<OpFoldResult> targetShape,
903 ArrayRef<Value> masterOperands) {
904 int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
905 assert((int64_t)targetShape.size() == rank);
906 assert((int64_t)masterOperands.size() == rank);
907 for (auto index : llvm::seq<int64_t>(0, rank))
908 operand =
909 broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
910 targetShape[index], masterOperands[index]);
911 return operand;
912}
913
914static SmallVector<Value>
915broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
916                           IndexPool &indexPool, ValueRange operands,
917 ArrayRef<OpFoldResult> targetShape,
918 ArrayRef<Value> masterOperands) {
919 // No need to broadcast for unary operations
920 if (operands.size() == 1)
921 return operands;
922
923 // No need to broadcast for static shape
924 bool hasDynamic = false;
925 for (auto op : operands) {
926 const auto tType = dyn_cast<RankedTensorType>(op.getType());
927 if (tType && !tType.hasStaticShape()) {
928 hasDynamic = true;
929 break;
930 }
931 }
932 if (!hasDynamic)
933 return operands;
934
935 // Broadcast dynamic dimensions operand by operand
936 return llvm::map_to_vector(operands, [&](Value operand) {
937 return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
938 targetShape, masterOperands);
939 });
940}
941
942static LogicalResult
943emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
944 Operation *operation, ValueRange operands,
945 ArrayRef<OpFoldResult> targetShape,
946 const TypeConverter &converter) {
947 // Generate output tensor
948 auto resultType = cast_or_null<RankedTensorType>(
949 converter.convertType(operation->getResultTypes().front()));
950 if (!resultType) {
951 return rewriter.notifyMatchFailure(operation, "failed to convert type");
952 }
953 Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
954 resultType.getElementType());
955
956 // Create affine maps. Input affine maps broadcast static dimensions of size
957 // 1. The output affine map is an identity map.
958 //
959 auto rank = resultType.getRank();
960 auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
961 auto shape = cast<ShapedType>(operand.getType()).getShape();
962 SmallVector<AffineExpr> affineExprs;
963 for (auto it : llvm::enumerate(shape)) {
964      // Prefer producing identity maps whenever possible (i.e. no broadcasting
965      // is needed) because some transforms (like reshape folding)
966      // do not support affine constant exprs.
967 bool requiresBroadcast =
968 (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
969 auto affineExpr = requiresBroadcast
970 ? rewriter.getAffineConstantExpr(0)
971 : rewriter.getAffineDimExpr(it.index());
972 affineExprs.push_back(affineExpr);
973 }
974 return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
975 });
976 affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
977
978 // Emit 'linalg.generic' op
979 bool encounteredError = false;
980 auto linalgOp = linalg::GenericOp::create(
981 rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
982      getNParallelLoopsAttrs(rank),
983      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
984        Value opResult = createLinalgBodyCalculationForElementwiseOp(
985 operation, blockArgs.take_front(operation->getNumOperands()),
986 {resultType.getElementType()}, rewriter);
987 if (!opResult) {
988 encounteredError = true;
989 return;
990 }
991 linalg::YieldOp::create(opBuilder, loc, opResult);
992 });
993 if (encounteredError)
994 return rewriter.notifyMatchFailure(
995 operation, "unable to create linalg.generic body for elementwise op");
996
997 // Cast 'linalg.generic' result into original result type if needed
998 auto castResult = rewriter.createOrFold<tensor::CastOp>(
999 loc, resultType, linalgOp->getResult(0));
1000 rewriter.replaceOp(operation, castResult);
1001 return success();
1002}
1003
1004static ValueRange getBroadcastableOperands(Operation *operation,
1005                                           ValueRange operands) {
1006 // Shift cannot broadcast
1007 if (isa<tosa::MulOp>(operation)) {
1008 DenseElementsAttr shiftElems;
1009 // Shift cannot broadcast when it is constant
1010 if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
1011 return operands.take_front(2);
1012 else
1013 return operands.take_front(3);
1014 }
1015 if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1016 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
1017 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
1018 if (failed(maybeOutZp) && failed(maybeInZp))
1019 return operands;
1020 // Input1_zp and output_zp cannot broadcast when they are constants.
1021 return operands.take_front(1);
1022 }
1023 return operands;
1024}
1025
1026static LogicalResult
1027elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
1028                                 ConversionPatternRewriter &rewriter,
1029 const TypeConverter &converter) {
1030
1031 // Collect op properties
1032 assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1033 assert(operation->getNumOperands() >= 1 &&
1034 "elementwise op expects at least 1 operand");
1035 if (!operandsAndResultsRanked(operation))
1036 return rewriter.notifyMatchFailure(operation,
1037 "Unranked tensors not supported");
1038
1039 // Lower operation
1040 IndexPool indexPool;
1041 auto loc = operation->getLoc();
1042 auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1043 auto [targetShape, masterOperands] =
1044 computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
1045 auto broadcastOperands =
1046 broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
1047 targetShape, masterOperands);
1048 return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
1049 targetShape, converter);
1050}
1051
1052// Returns the constant initial value for a given reduction operation. The
1053// attribute type varies depending on the element type required.
1054static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1055 PatternRewriter &rewriter) {
1056 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1057 return rewriter.getFloatAttr(elementTy, 0.0);
1058
1059 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1060 return rewriter.getIntegerAttr(elementTy, 0);
1061
1062 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1063 return rewriter.getFloatAttr(elementTy, 1.0);
1064
1065 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1066 return rewriter.getIntegerAttr(elementTy, 1);
1067
1068 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1069 return rewriter.getFloatAttr(
1070 elementTy, APFloat::getLargest(
1071 cast<FloatType>(elementTy).getFloatSemantics(), false));
1072
1073 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1074 return rewriter.getIntegerAttr(
1075 elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1076
1077 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1078 return rewriter.getFloatAttr(
1079 elementTy, APFloat::getLargest(
1080 cast<FloatType>(elementTy).getFloatSemantics(), true));
1081
1082 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1083 return rewriter.getIntegerAttr(
1084 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1085
1086 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1087 return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1088
1089 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1090 return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1091
1092 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1093 return rewriter.getFloatAttr(
1094 elementTy, APFloat::getLargest(
1095 cast<FloatType>(elementTy).getFloatSemantics(), true));
1096
1097 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1098 return rewriter.getIntegerAttr(
1099 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1100
1101 return {};
1102}
1103
1104// Creates the body calculation for a reduction. The operations vary depending
1105// on the input type.
1106static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1107                                                     ValueRange args,
1108 Type elementTy,
1109 PatternRewriter &rewriter) {
1110 Location loc = op->getLoc();
1111 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1112 return arith::AddFOp::create(rewriter, loc, args);
1113 }
1114
1115 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1116 return arith::AddIOp::create(rewriter, loc, args);
1117 }
1118
1119 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1120 return arith::MulFOp::create(rewriter, loc, args);
1121 }
1122
1123 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1124 return arith::MulIOp::create(rewriter, loc, args);
1125 }
1126
1127 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1128 return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
1129 }
1130
1131 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1132 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
1133 }
1134
1135 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1136 return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
1137 }
1138
1139 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1140 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
1141 }
1142
1143 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1144 return arith::AndIOp::create(rewriter, loc, args);
1145
1146 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1147 return arith::OrIOp::create(rewriter, loc, args);
1148
1149 return {};
1150}
1151
1152// Performs the match and rewrite for reduction operations. This includes
1153// declaring a correctly sized initial value, and the linalg.generic operation
1154// that reduces across the specified axis.
1155template <typename OpTy>
1156static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1157 PatternRewriter &rewriter) {
1158 auto loc = op->getLoc();
1159 auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1160 auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1161 if (!inputTy || !resultTy)
1162 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1163
1164 auto elementTy = resultTy.getElementType();
1165 Value input = op->getOperand(0);
1166
1167 // Figure out the accType if needed
1168 bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1169 isa<FloatType>(elementTy) &&
1170 cast<FloatType>(elementTy).isBF16();
1171 Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
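  // e.g. a bf16 tosa.reduce_sum accumulates in f32, and the result is
  // truncated back to bf16 after the reduction (see the trunc generic below).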
1172
1173 SmallVector<int64_t> reduceShape;
1174 SmallVector<Value> dynDims;
1175 for (unsigned i = 0; i < inputTy.getRank(); i++) {
1176 if (axis != i) {
1177 reduceShape.push_back(inputTy.getDimSize(i));
1178 if (inputTy.isDynamicDim(i))
1179 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1180 }
1181 }
1182
1183 SmallVector<Value> inputs, outputs;
1184 inputs.push_back(input);
1185
1186 // First fill the output buffer with the init value.
1187 auto emptyTensor =
1188 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1189 .getResult();
1190
1191 auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
1192 if (!fillValueAttr)
1193 return rewriter.notifyMatchFailure(
1194 op, "No initial value found for reduction operation");
1195
1196 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
1197 auto filledTensor =
1198 linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
1199 ValueRange{emptyTensor})
1200 .result();
1201 outputs.push_back(filledTensor);
1202
1203 bool isNanIgnoreMode = false;
1204 if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1205 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1206 // NaN propagation has no meaning for non floating point types.
1207 if (isa<FloatType>(elementTy) &&
1208 op.getNanMode() == NanPropagationMode::IGNORE) {
1209 isNanIgnoreMode = true;
1210 // Because the TOSA spec requires the result be NaN iff all elements in
1211 // the reduction are NaN we can't simply perform a compare and select.
1212 // Additionally we have to keep track of whether we've seen any non-NaN
1213 // values and then do a final select based on this predicate.
1214 auto trueAttr = rewriter.getBoolAttr(true);
1215 auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
1216 auto emptyBoolTensor =
1217 tensor::EmptyOp::create(rewriter, loc, reduceShape,
1218 trueValue.getType(), dynDims)
1219 .getResult();
1220 auto allResultsNaNTensor =
1221 linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
1222 ValueRange{emptyBoolTensor})
1223 .result();
1224 // Note that because the linalg::ReduceOp has two variadic arguments
1225 // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1226 // need to have the same number of inputs and outputs.
1227 //
1228 // The second input isn't actually used anywhere since the value used to
1229 // update the NaN flag is calculated inside the body of the reduction and
1230 // then used to update an out value.
1231 // In order to satisfy type constraints we just pass another copy of the
1232 // input here.
1233 inputs.push_back(input);
1234 outputs.push_back(allResultsNaNTensor);
1235 }
1236 }
1237
1238 bool didEncounterError = false;
1239 linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
1240 rewriter, loc, inputs, outputs, axis,
1241 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1242 std::array<Value, 2> binaryArgs{
1243 blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1244
1245 // If reduction type differs then extend (applicable to reduce_sum)
1246 if (binaryArgs[0].getType() != accTy)
1247 binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
1248 binaryArgs[0]);
1249
1250 auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
1251 accTy, rewriter);
1252 if (result)
1253 didEncounterError = true;
1254
1255 SmallVector<Value> resultsToYield;
1256 if (isNanIgnoreMode) {
1257 auto inputValue = blockArgs[0];
1258 auto initialValue = blockArgs[2];
1259 auto oldAllResultsNanFlagValue = blockArgs[3];
1260
1261 // Unordered comparison of NaN against itself will always return true.
1262 Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
1263 arith::CmpFPredicate::UNO,
1264 inputValue, inputValue);
1265 // If we've encountered a NaN, take the non-NaN value.
1266 auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
1267 isNaN, initialValue, result);
1268 // Update the flag which keeps track of whether we have seen a non-NaN
1269 // value.
1270 auto newAllResultsNanFlagValue = arith::AndIOp::create(
1271 nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1272 resultsToYield.push_back(selectOp);
1273 resultsToYield.push_back(newAllResultsNanFlagValue);
1274 } else {
1275 resultsToYield.push_back(result);
1276 }
1277 linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
1278 });
1279
1280 if (!didEncounterError)
1281 return rewriter.notifyMatchFailure(
1282 op, "unable to create linalg.generic body for reduce op");
1283
1284 if (isNanIgnoreMode) {
1285    // Materialize a check to see whether we encountered any non-NaN values. If
1286    // we didn't, we need to select a tensor of NaNs, since the result would
1287    // otherwise just be the initial identity value propagated through all the
1288    // compares and selects inside the reduction.
1289
1290 // Create a tensor full of NaNs.
1291 auto nanValueAttr = rewriter.getFloatAttr(
1292 accTy,
1293 APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1294 auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
1295 auto emptyNanTensor =
1296 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1297 .getResult();
1298 auto nanFilledTensor =
1299 linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
1300 ValueRange{emptyNanTensor})
1301 .result();
1302
1303    // Create an empty tensor; no need to fill it since it will be
1304    // overwritten by the select.
1305 auto finalEmptyTensor =
1306 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1307 .getResult();
1308
1309 // Do a selection between the tensors akin to:
1310 // result = NaN if "all results NaN" else result.
1311 SmallVector<Value> ins, outs;
1312 ins.push_back(linalgOp->getOpResult(1));
1313 ins.push_back(nanFilledTensor);
1314 ins.push_back(linalgOp->getResult(0));
1315 outs.push_back(finalEmptyTensor);
1316 auto linalgSelect =
1317 linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
1318 linalgOp = linalgSelect;
1319 }
1320
1321 // Truncate back to resultTy if needed
1322 Value reducedRes = linalgOp->getResult(0);
1323 if (widenAccTy) {
1324 auto resEmptyOp =
1325 tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
1326 .getResult();
1327
1328 const unsigned reducedRank =
1329 cast<ShapedType>(reducedRes.getType()).getRank();
1330 auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
1331 reducedRes =
1332 linalg::GenericOp::create(
1333 rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
1334 ValueRange{resEmptyOp},
1335 ArrayRef<AffineMap>{identityMap, identityMap},
1336 getNParallelLoopsAttrs(reducedRank),
1337 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1338 Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
1339 elementTy, args[0]);
1340 linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
1341 })
1342 .getResults()[0];
1343 }
1344
1345 SmallVector<ReassociationExprs, 4> reassociationMap;
1346 uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
1347 reassociationMap.resize(expandInputRank);
1348
1349 for (uint64_t i = 0; i < expandInputRank; i++) {
1350 int32_t dimToPush = i > axis ? i + 1 : i;
1351 reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1352 }
1353
1354 if (expandInputRank != 0) {
1355 int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1356 reassociationMap[expandedDim].push_back(
1357 rewriter.getAffineDimExpr(expandedDim + 1));
1358 }
1359
1360 // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1361 // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1362 // not have access to such information. This matters when handling dynamically
1363 // sized tensors.
1364 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
1365 reassociationMap);
1366 return success();
1367}
1368
1369namespace {
1370
1371template <typename SrcOp>
1372class PointwiseConverter : public OpConversionPattern<SrcOp> {
1373public:
1374 using OpConversionPattern<SrcOp>::OpConversionPattern;
1375 using typename OpConversionPattern<SrcOp>::OpAdaptor;
1376
1377 LogicalResult
1378 matchAndRewrite(SrcOp op, OpAdaptor operands,
1379 ConversionPatternRewriter &rewriter) const final {
1380    return elementwiseMatchAndRewriteHelper(
1381        op, operands.getOperands(), rewriter, *this->getTypeConverter());
1382 }
1383};
1384
1385// Collapse tensor<1xiN> into tensor<iN>
1386// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1387static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
1388 Location loc) {
1389  SmallVector<ReassociationExprs, 1> reassociation;
1390  // Create the collapsed type
1391 auto inputType = cast<RankedTensorType>(input.getType());
1392 auto elemType = inputType.getElementType();
1393 auto collapsedType = RankedTensorType::get({}, elemType);
1394 // Emit the collapse op
1395 return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
1396 reassociation);
1397}
1398
1399static llvm::SmallVector<int8_t>
1400convertToI8(const llvm::SmallVector<int32_t> &input) {
1401  llvm::SmallVector<int8_t> output;
1402  output.reserve(input.size());
1403
1404 for (auto v : llvm::map_range(
1405 input, [](int32_t val) { return static_cast<int8_t>(val); })) {
1406 output.push_back(v);
1407 }
1408 return output;
1409}
1410
1411// The shift or multiplier may be either constant or non-constant, depending on
1412// whether dynamic extension is enabled.
1413// - If the shift or multiplier is non-constant, add it as an input to
1414// linalg::GenericOp by:
1415// 1. Pushing it into 'genericInputs'.
1416// 2. Appending a corresponding affine map to 'indexingMaps'.
1417// - If the shift or multiplier is constant, set 'constant' instead.
1418static void setupLinalgGenericOpInputAndIndexingMap(
1419    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
1420    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1421 bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
1422 bool isShift = false) {
1423
1424 auto loc = op.getLoc();
1425 auto inputTy = cast<ShapedType>(op.getInput().getType());
1426 unsigned rank = inputTy.getRank();
1427 SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
1428
1429 if (isConstant) {
1430 // If we are rescaling per-channel then we need to store the
1431 // values in a buffer.
1432 if (values.size() == 1) {
1433 IntegerAttr intAttr = isShift
1434 ? rewriter.getI8IntegerAttr(values.front())
1435 : rewriter.getI32IntegerAttr(values.front());
1436 constant = arith::ConstantOp::create(rewriter, loc, intAttr);
1437 } else {
1438 auto elementType =
1439 isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
1440 auto tensorType = RankedTensorType::get(
1441 {static_cast<int64_t>(values.size())}, elementType);
1442 DenseIntElementsAttr EltAttr;
1443 if (isShift)
1444 EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
1445 else
1446 EltAttr = DenseIntElementsAttr::get(tensorType, values);
1447 genericInputs.push_back(
1448 arith::ConstantOp::create(rewriter, loc, EltAttr));
1449 indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1450 /*symbolCount=*/0, exprs,
1451 rewriter.getContext()));
1452 }
1453 } else {
1454 // If we are not rescaling per-channel then we need to collapse 1xN to N
1455 // and push broadcastMap.
1456 auto operand = isShift ? op.getShift() : op.getMultiplier();
1457 auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
1458 if (tensorType && tensorType.hasStaticShape() &&
1459 tensorType.getShape()[0] == 1) {
1460 // broadcastMap = affine_map<(d0, d1) -> ()>
1461        // This acts as a broadcast of the scalar value in linalg::GenericOp.
1462 AffineMap broadcastMap =
1463 AffineMap::get(rank, 0, {}, rewriter.getContext());
1464 genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
1465 indexingMaps.push_back(broadcastMap);
1466 } else {
1467 genericInputs.push_back(operand);
1468 indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1469 /*symbolCount=*/0, exprs,
1470 rewriter.getContext()));
1471 }
1472 }
1473 arg = indexingMaps.size() - 1;
1474}
1475
1476// Return the extended Zp to be used in subsequent arithmetic operations.
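// A minimal sketch of the two paths, assuming an i8 zero point extended for
// i32 arithmetic: a constant zp of 42 is emitted as 'arith.constant 42 : i32',
// while a non-constant zp is read from the corresponding linalg.generic block
// argument and widened with arith.extsi (or arith.extui for unsigned values).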
1477static Value getExtendZp(OpBuilder &builder, Type valueTy,
1478 FailureOr<int64_t> maybeZp, Location loc,
1479 ValueRange blockArgs, int64_t zpArg,
1480 bool isOutputZp = false) {
1481 Value result;
1482 const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1483 const uint32_t attrBitwidth =
1484 isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1485 auto extendType = builder.getIntegerType(attrBitwidth);
1486 // The Zp value can be either constant or non-constant, depending on
1487 // whether dynamic extension is enabled.
1488 // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1489 // be passed as an input to linalg::GenericOp.
1490 if (failed(maybeZp)) {
1491 result = blockArgs[zpArg];
1492 auto zpTy = result.getType();
1493 if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1494 // For ExtUIOp, the input must be signless.
1495 // UnrealizedConversionCastOp will cast the input to signless type.
1496 if (zpTy.isUnsignedInteger()) {
1497 result =
1498 UnrealizedConversionCastOp::create(
1499 builder, loc,
1500 builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
1501 .getResult(0);
1502 }
1503 if (zpTy.isUnsignedInteger()) {
1504 return arith::ExtUIOp::create(builder, loc, extendType, result);
1505 } else {
1506 return arith::ExtSIOp::create(builder, loc, extendType, result);
1507 }
1508 }
1509 } else {
1510 return arith::ConstantOp::create(builder, loc,
1511 IntegerAttr::get(extendType, *maybeZp));
1512 }
1513 return result;
1514}
1515
1516class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1517public:
1518 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1519
1520 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1521 PatternRewriter &rewriter) const final {
1522 auto loc = op.getLoc();
1523 auto input = op.getInput();
1524 auto inputTy = cast<ShapedType>(op.getInput().getType());
1525 auto outputTy = cast<ShapedType>(op.getOutput().getType());
1526 unsigned rank = inputTy.getRank();
1527
1528 // This is an illegal configuration; terminate and log an error.
1529 if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
1530 return rewriter.notifyMatchFailure(
1531 op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1532 "currently supported");
1533 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
1534 return rewriter.notifyMatchFailure(
1535 op, "tosa.rescale requires scale32 for double_round to be true");
1536
1537 if (!isa<IntegerType>(inputTy.getElementType()))
1538 return rewriter.notifyMatchFailure(op, "only support integer type");
1539
1540 SmallVector<Value> dynDims;
1541 for (int i = 0; i < outputTy.getRank(); i++) {
1542 if (outputTy.isDynamicDim(i)) {
1543 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1544 }
1545 }
1546
1547 DenseElementsAttr shiftElems;
1548 bool isShiftConstant = false;
1549 if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
1550 isShiftConstant = true;
1551
1552 DenseElementsAttr multiplierElems;
1553 bool isMultiplierConstant = false;
1554 if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1555 isMultiplierConstant = true;
1556
1557 llvm::SmallVector<int32_t> shiftValues;
1558 llvm::SmallVector<int32_t> multiplierValues;
1559 bool doubleRound;
1560
1561 if (isMultiplierConstant && isShiftConstant) {
1562 // explicit cast is required here
1563 shiftValues = llvm::to_vector(llvm::map_range(
1564 shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1565 return static_cast<int32_t>(attr.getInt());
1566 }));
1567 multiplierValues = llvm::to_vector(
1568 llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1569 [](IntegerAttr attr) -> int32_t {
1570 return static_cast<int32_t>(attr.getInt());
1571 }));
1572
1573 // If we shift by more than the bitwidth, this just sets to 0.
1574 for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1575 if (shiftValues[i] > 63) {
1576 shiftValues[i] = 0;
1577 multiplierValues[i] = 0;
1578 }
1579 }
1580 // Double rounding only occurs if the shift is greater than 31; check
1581 // whether this is ever true.
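// For instance (illustrative values only): with DOUBLE_ROUND requested and a
// constant shift vector of {14, 38}, the entry 38 > 31 keeps double rounding
// enabled, whereas an all-small vector such as {14, 20} silently falls back
// to SINGLE_ROUND below.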
1582 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1583 llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1584 } else
1585 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
1586
1587 RoundingMode roundingMode =
1588 doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
1589
1590 SmallVector<AffineMap> indexingMaps = {
1591 rewriter.getMultiDimIdentityMap(rank)};
1592 SmallVector<Value, 4> genericInputs = {input};
1593
1594 // If we are rescaling per-channel then we need to store the multiplier
1595 // values in a buffer.
1596 Value multiplierConstant;
1597 int64_t multiplierArg = 0;
1598 setupLinalgGenericOpInputAndIndexingMap(
1599 rewriter, multiplierValues, genericInputs, indexingMaps,
1600 isMultiplierConstant, op, multiplierConstant, multiplierArg);
1601
1602 // If we are rescaling per-channel then we need to store the shift
1603 // values in a buffer.
1604 Value shiftConstant;
1605 int64_t shiftArg = 0;
1606 setupLinalgGenericOpInputAndIndexingMap(
1607 rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1608 shiftConstant, shiftArg, true);
1609
1610 // broadcastMap = affine_map<(d0, d1) -> ()>
1611 // It acts as a broadcast for scalar values in linalg::GenericOp.
1612 AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
1613 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1614 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1615 // The inputZp and outputZp may be either constant or non-constant,
1616 // depending on whether dynamic extension is enabled.
1617 // - If the zp's are non-constant, add them as inputs to
1618 // linalg::GenericOp by:
1619 // 1. Pushing them into 'genericInputs'.
1620 // 2. Appending corresponding affine maps to 'indexingMaps'.
1621 // - If the zp's are constant, they are materialized as arith.constant ops.
1622 int64_t iZpArg = 0;
1623 if (failed(maybeIZp)) {
1624 genericInputs.push_back(
1625 collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1626 indexingMaps.push_back(broadcastMap);
1627 iZpArg = indexingMaps.size() - 1;
1628 }
1629 int64_t oZpArg = 0;
1630 if (failed(maybeOZp)) {
1631 genericInputs.push_back(
1632 collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1633 indexingMaps.push_back(broadcastMap);
1634 oZpArg = indexingMaps.size() - 1;
1635 }
1636
1637 // Indexing maps for output values.
1638 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1639
1640 // Create the empty output tensor for the linalg.generic op.
1641 Value emptyTensor = tensor::EmptyOp::create(
1642 rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
1643 ArrayRef<Value>({dynDims}));
1644
1645 auto linalgOp = linalg::GenericOp::create(
1646 rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
1647 indexingMaps, getNParallelLoopsAttrs(rank),
1648 [&](OpBuilder &nestedBuilder, Location nestedLoc,
1649 ValueRange blockArgs) {
1650 Value value = blockArgs[0];
1651 Type valueTy = value.getType();
1652
1653 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1654 auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
1655 nestedLoc, blockArgs, iZpArg);
1656
1657 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1658 auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
1659 nestedLoc, blockArgs, oZpArg, true);
1660
1661 IntegerType outIntType =
1662 cast<IntegerType>(blockArgs.back().getType());
1663 unsigned outBitWidth = outIntType.getWidth();
1664 assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1665
1666 Value multiplier = multiplierConstant ? multiplierConstant
1667 : blockArgs[multiplierArg];
1668 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1669
1670 if (valueTy.isUnsignedInteger()) {
1671 value = UnrealizedConversionCastOp::create(
1672 nestedBuilder, nestedLoc,
1673 nestedBuilder.getIntegerType(
1674 valueTy.getIntOrFloatBitWidth()),
1675 value)
1676 .getResult(0);
1677 }
1678 if (valueTy.getIntOrFloatBitWidth() < 32) {
1679 if (op.getInputUnsigned()) {
1680 value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
1681 nestedBuilder.getI32Type(), value);
1682 } else {
1683 value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
1684 nestedBuilder.getI32Type(), value);
1685 }
1686 }
1687
1688 value =
1689 arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
1690
1691 value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
1692 nestedBuilder.getI32Type(), value,
1693 multiplier, shift, roundingMode);
1694
1695 // Move to the new zero-point.
1696 value =
1697 arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
1698
1699 // Saturate to the output size.
1700 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1701 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1702
1703 // Unsigned integers have a different output value range.
1704 if (op.getOutputUnsigned()) {
1705 intMin = 0;
1706 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1707 }
1708
1709 auto intMinVal = arith::ConstantOp::create(
1710 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
1711 auto intMaxVal = arith::ConstantOp::create(
1712 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
1713
1714 value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1715 nestedBuilder, /*isUnsigned=*/false);
1716
1717 if (outIntType.getWidth() < 32) {
1718 value = arith::TruncIOp::create(
1719 nestedBuilder, nestedLoc,
1720 rewriter.getIntegerType(outIntType.getWidth()), value);
1721 }
1722
1723 if (outIntType.isUnsignedInteger()) {
1724 value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
1725 outIntType, value)
1726 .getResult(0);
1727 }
1728 linalg::YieldOp::create(nestedBuilder, loc, value);
1729 });
1730
1731 rewriter.replaceOp(op, linalgOp->getResults());
1732 return success();
1733 }
1734};
1735
1736 // Handle the resize case where the input is a 1x1 image. This case
1737 // can entirely avoid extract operations, which are much more difficult
1738 // to optimize away.
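// A rough sketch of the rewrite, assuming a bilinear i8 -> i32 resize of
// tensor<?x1x1x8xi8>: the unit H/W dims are collapsed away, a linalg.generic
// extends each element to i32 and multiplies it by the constant scale
// numerators, and a tensor.expand_shape restores the unit spatial dims in
// the result type.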
1739class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1740public:
1741 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1742
1743 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1744 PatternRewriter &rewriter) const final {
1745 Location loc = op.getLoc();
1746 ImplicitLocOpBuilder builder(loc, rewriter);
1747 auto input = op.getInput();
1748 auto inputTy = cast<RankedTensorType>(input.getType());
1749 auto resultTy = cast<RankedTensorType>(op.getType());
1750 const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
1751
1752 auto inputH = inputTy.getDimSize(1);
1753 auto inputW = inputTy.getDimSize(2);
1754 auto outputH = resultTy.getDimSize(1);
1755 auto outputW = resultTy.getDimSize(2);
1756
1757 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1758 return rewriter.notifyMatchFailure(
1759 op, "tosa.resize is not a pure 1x1->1x1 image operation");
1760
1761 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1762 op.getMode() != ResizeMode::BILINEAR)
1763 return rewriter.notifyMatchFailure(
1764 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1765
1766 if (inputTy == resultTy) {
1767 rewriter.replaceOp(op, input);
1768 return success();
1769 }
1770
1771 SmallVector<int64_t> scale;
1772 if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
1773 return failure();
1774 }
1775
1776 // Collapse the unit width and height away.
1777 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1778 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1779 reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1780 reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1781 reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1782
1783 auto collapseTy =
1784 RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1785 inputTy.getElementType());
1786 Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
1787 reassociationMap);
1788
1789 // Get any dynamic shapes that appear in the input format.
1790 llvm::SmallVector<Value> outputDynSize;
1791 if (inputTy.isDynamicDim(0))
1792 outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1793 if (inputTy.isDynamicDim(3))
1794 outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1795
1796 // Generate the elementwise operation for casting and scaling the input value.
1797 auto genericTy = collapseTy.clone(resultTy.getElementType());
1798 Value empty =
1799 tensor::EmptyOp::create(builder, genericTy.getShape(),
1800 resultTy.getElementType(), outputDynSize);
1801 auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1802 SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1803 utils::IteratorType::parallel);
1804
1805 auto generic = linalg::GenericOp::create(
1806 builder, genericTy, ValueRange{collapse}, ValueRange{empty},
1807 ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1808 [=](OpBuilder &b, Location loc, ValueRange args) {
1809 Value value = args[0];
1810 // This is the quantized case.
1811 if (inputTy.getElementType() != resultTy.getElementType()) {
1812 value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
1813 value);
1814
1815 if (isBilinear && scale[0] != 0) {
1816 Value scaleY = arith::ConstantOp::create(
1817 b, loc, b.getI32IntegerAttr(scale[0]));
1818 value = arith::MulIOp::create(b, loc, value, scaleY);
1819 }
1820
1821 if (isBilinear && scale[2] != 0) {
1822 Value scaleX = arith::ConstantOp::create(
1823 b, loc, b.getI32IntegerAttr(scale[2]));
1824 value = arith::MulIOp::create(b, loc, value, scaleX);
1825 }
1826 }
1827
1828 linalg::YieldOp::create(b, loc, value);
1829 });
1830
1831 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1832 op, resultTy, generic.getResults()[0], reassociationMap);
1833 return success();
1834 }
1835};
1836
1837// TOSA resize with width or height of 1 may be broadcasted to a wider
1838// dimension. This is done by materializing a new tosa.resize without
1839// the broadcasting behavior, and an explicit broadcast afterwards.
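// Sketch under assumed shapes: resizing tensor<1x1x10x3xf32> to
// tensor<1x5x10x3xf32> is rewritten as a tosa.resize that keeps the unit H
// dim (tensor<1x1x10x3xf32>), a tensor.collapse_shape to tensor<1x10x3xf32>,
// and a broadcasting linalg.generic that copies the collapsed values across
// the output's H dimension.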
1840class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1841public:
1842 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1843
1844 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1845 PatternRewriter &rewriter) const final {
1846 Location loc = op.getLoc();
1847 ImplicitLocOpBuilder builder(loc, rewriter);
1848 auto input = op.getInput();
1849 auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1850 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1851
1852 if (!inputTy || !resultTy)
1853 return rewriter.notifyMatchFailure(op,
1854 "requires ranked input/output types");
1855
1856 auto batch = inputTy.getDimSize(0);
1857 auto channels = inputTy.getDimSize(3);
1858 auto inputH = inputTy.getDimSize(1);
1859 auto inputW = inputTy.getDimSize(2);
1860 auto outputH = resultTy.getDimSize(1);
1861 auto outputW = resultTy.getDimSize(2);
1862
1863 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1864 return rewriter.notifyMatchFailure(
1865 op, "tosa.resize has no broadcasting behavior");
1866
1867 // For any dimension that is broadcastable we generate a size of 1
1868 // on the output.
1869 llvm::SmallVector<int64_t> resizeShape;
1870 resizeShape.push_back(batch);
1871 resizeShape.push_back(inputH == 1 ? 1 : outputH);
1872 resizeShape.push_back(inputW == 1 ? 1 : outputW);
1873 resizeShape.push_back(channels);
1874
1875 auto resizeTy = resultTy.clone(resizeShape);
1876 auto resize =
1877 tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
1878 op.getOffset(), op.getBorder(), op.getMode());
1879
1880 // Collapse any unit result dims.
1881 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1882 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1883 reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1884 if (inputH != 1)
1885 reassociationMap.push_back({});
1886 reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1887 if (inputW != 1)
1888 reassociationMap.push_back({});
1889 reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1890
1891 llvm::SmallVector<int64_t> collapseShape = {batch};
1892 if (inputH != 1)
1893 collapseShape.push_back(outputH);
1894 if (inputW != 1)
1895 collapseShape.push_back(outputW);
1896 collapseShape.push_back(channels);
1897
1898 auto collapseTy = resultTy.clone(collapseShape);
1899 Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
1900 resize, reassociationMap);
1901
1902 // Broadcast the collapsed shape to the output result.
1903 llvm::SmallVector<Value> outputDynSize;
1904 if (inputTy.isDynamicDim(0))
1905 outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1906 if (inputTy.isDynamicDim(3))
1907 outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1908
1909 SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1910 utils::IteratorType::parallel);
1911 Value empty = tensor::EmptyOp::create(
1912 builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1913
1914 SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1915 if (inputH != 1)
1916 inputExprs.push_back(rewriter.getAffineDimExpr(1));
1917 if (inputW != 1)
1918 inputExprs.push_back(rewriter.getAffineDimExpr(2));
1919 inputExprs.push_back(rewriter.getAffineDimExpr(3));
1920
1921 auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1922 inputExprs, rewriter.getContext());
1923
1924 auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1925 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1926 op, resultTy, ValueRange{collapse}, ValueRange{empty},
1927 ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1928 [=](OpBuilder &b, Location loc, ValueRange args) {
1929 Value value = args[0];
1930 linalg::YieldOp::create(b, loc, value);
1931 });
1932
1933 return success();
1934 }
1935};
1936
1937class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1938public:
1939 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1940
1941 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1942 PatternRewriter &rewriter) const final {
1943 Location loc = op.getLoc();
1944 ImplicitLocOpBuilder b(loc, rewriter);
1945 auto input = op.getInput();
1946 auto inputTy = cast<ShapedType>(input.getType());
1947 auto resultTy = cast<ShapedType>(op.getType());
1948 auto resultETy = resultTy.getElementType();
1949
1950 bool floatingPointMode = isa<FloatType>(resultETy);
1951 auto floatTy = resultETy;
1952
1953 auto imageH = inputTy.getShape()[1];
1954 auto imageW = inputTy.getShape()[2];
1955
1956 auto dynamicDimsOr =
1957 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1958 if (!dynamicDimsOr.has_value())
1959 return rewriter.notifyMatchFailure(
1960 op, "unable to get dynamic dimensions of tosa.resize");
1961
1962 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1963 op.getMode() != ResizeMode::BILINEAR)
1964 return rewriter.notifyMatchFailure(
1965 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1966
1967 SmallVector<AffineMap, 2> affineMaps = {
1968 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1969 auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
1970 resultETy, *dynamicDimsOr);
1971 auto genericOp = linalg::GenericOp::create(
1972 b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1973 getNParallelLoopsAttrs(resultTy.getRank()));
1974 Value resize = genericOp.getResult(0);
1975
1976 {
1977 OpBuilder::InsertionGuard regionGuard(b);
1978 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1979 TypeRange({resultETy}), loc);
1980 Value batch = linalg::IndexOp::create(b, 0);
1981 Value y = linalg::IndexOp::create(b, 1);
1982 Value x = linalg::IndexOp::create(b, 2);
1983 Value channel = linalg::IndexOp::create(b, 3);
1984
1985 Value zeroI32 =
1986 arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
1987 Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
1988 Value hMax =
1989 arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
1990 Value wMax =
1991 arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));
1992
1993 Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
1994 Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);
1995
1996 SmallVector<int64_t> scale, offset, border;
1997 if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
1998 !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
1999 !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
2000 return rewriter.notifyMatchFailure(
2001 op, "tosa.resize scale/offset/border should have compile time "
2002 "constant values.");
2003 }
2004
2005 Value yScaleN, yScaleD, xScaleN, xScaleD;
2006 yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
2007 yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
2008 xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
2009 xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));
2010
2011 Value yOffset, xOffset, yBorder, xBorder;
2012 yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
2013 xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
2014 yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
2015 xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
2016
2017 // Compute the ix and dx values for both the X and Y dimensions.
2018 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
2019 Value scaleN, Value scaleD, Value offset,
2020 int size, ImplicitLocOpBuilder &b) {
2021 if (size == 1) {
2022 index = zeroI32;
2023 delta = zeroFp;
2024 return;
2025 }
2026 // x = x * scale_d + offset;
2027 // ix = floor(x / scale_n)
2028 Value val = arith::MulIOp::create(b, in, scaleD);
2029 val = arith::AddIOp::create(b, val, offset);
2030 index = arith::FloorDivSIOp::create(b, val, scaleN);
2031
2032 // rx = x % scale_n
2033 // dx = rx / scale_n
2034 Value r = arith::RemSIOp::create(b, val, scaleN);
2035 Value rFp = arith::SIToFPOp::create(b, floatTy, r);
2036 Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
2037 delta = arith::DivFOp::create(b, rFp, scaleNfp);
2038 };
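// Worked example (illustrative numbers only): with in = 3, scale_n = 4,
// scale_d = 2 and offset = 1, the lambda above computes x = 3 * 2 + 1 = 7,
// ix = floor(7 / 4) = 1, rx = 7 % 4 = 3 and dx = 3.0 / 4.0 = 0.75; the
// integer variant below would instead yield dx = 7 - 1 * 4 = 3.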
2039
2040 // Compute the ix and dx values for the X and Y dimensions - int case.
2041 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
2042 Value scaleN, Value scaleD, Value offset,
2043 int size, ImplicitLocOpBuilder &b) {
2044 if (size == 1) {
2045 index = zeroI32;
2046 delta = zeroI32;
2047 return;
2048 }
2049 // x = x * scale_d + offset;
2050 // ix = floor(x / scale_n)
2051 // dx = x - ix * scale_n;
2052 Value val = arith::MulIOp::create(b, in, scaleD);
2053 val = arith::AddIOp::create(b, val, offset);
2054 index = arith::DivSIOp::create(b, val, scaleN);
2055 delta = arith::MulIOp::create(b, index, scaleN);
2056 delta = arith::SubIOp::create(b, val, delta);
2057 };
2058
2059 Value ix, iy, dx, dy;
2060 if (floatingPointMode) {
2061 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2062 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2063 } else {
2064 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2065 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2066 }
2067
2068 if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
2069 auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2070
2071 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
2072 Value max, int size,
2073 ImplicitLocOpBuilder &b) -> Value {
2074 if (size == 1) {
2075 return arith::ConstantIndexOp::create(b, 0);
2076 }
2077
2078 Value pred;
2079 if (floatingPointMode) {
2080 auto h =
2081 arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
2082 pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
2083 } else {
2084 Value dvalDouble = arith::ShLIOp::create(b, dval, one);
2085 pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
2086 dvalDouble, scale);
2087 }
2088
2089 auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
2090 val = arith::AddIOp::create(b, val, offset);
2091 val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
2092 return arith::IndexCastOp::create(b, b.getIndexType(), val);
2093 };
2094
2095 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
2096 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
2097
2098 Value result = tensor::ExtractOp::create(
2099 b, input, ValueRange{batch, iy, ix, channel});
2100
2101 linalg::YieldOp::create(b, result);
2102 } else {
2103 // The mode here must be BILINEAR.
2104 assert(op.getMode() == ResizeMode::BILINEAR);
2105
2106 auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2107
2108 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
2109 Value max, ImplicitLocOpBuilder &b) {
2110 val0 = in;
2111 val1 = arith::AddIOp::create(b, val0, oneVal);
2112 val0 =
2113 clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
2114 val1 =
2115 clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
2116 val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
2117 val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
2118 };
2119
2120 // Linalg equivalent to the section below:
2121 // int16_t iy0 = apply_max(iy, 0);
2122 // int16_t iy1 = apply_min(iy + 1, IH - 1);
2123 // int16_t ix0 = apply_max(ix, 0);
2124 // int16_t ix1 = apply_min(ix + 1, IW - 1);
2125 Value x0, x1, y0, y1;
2126 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
2127 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
2128
2129 Value y0x0 = tensor::ExtractOp::create(
2130 b, input, ValueRange{batch, y0, x0, channel});
2131 Value y0x1 = tensor::ExtractOp::create(
2132 b, input, ValueRange{batch, y0, x1, channel});
2133 Value y1x0 = tensor::ExtractOp::create(
2134 b, input, ValueRange{batch, y1, x0, channel});
2135 Value y1x1 = tensor::ExtractOp::create(
2136 b, input, ValueRange{batch, y1, x1, channel});
2137
2138 if (floatingPointMode) {
2139 auto oneVal =
2140 arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
2141 auto interpolate = [&](Value val0, Value val1, Value delta,
2142 int inputSize,
2143 ImplicitLocOpBuilder &b) -> Value {
2144 if (inputSize == 1)
2145 return val0;
2146 Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
2147 Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
2148 Value mul1 = arith::MulFOp::create(b, val1, delta);
2149 return arith::AddFOp::create(b, mul0, mul1);
2150 };
2151
2152 // Linalg equivalent to the section below:
2153 // topAcc = v00 * (unit_x - dx);
2154 // topAcc += v01 * dx;
2155 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
2156
2157 // Linalg equivalent to the section below:
2158 // bottomAcc = v10 * (unit_x - dx);
2159 // bottomAcc += v11 * dx;
2160 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
2161
2162 // Linalg equivalent to the section below:
2163 // result = topAcc * (unit_y - dy) + bottomAcc * dy
2164 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
2165 linalg::YieldOp::create(b, result);
2166 } else {
2167 // Perform in quantized space.
2168 y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
2169 y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
2170 y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
2171 y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);
2172
2173 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
2174 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
2175 dx = arith::ExtSIOp::create(b, resultETy, dx);
2176 dy = arith::ExtSIOp::create(b, resultETy, dy);
2177 }
2178
2179 Value yScaleNExt = yScaleN;
2180 Value xScaleNExt = xScaleN;
2181
2182 const int64_t scaleBitwidth =
2183 xScaleN.getType().getIntOrFloatBitWidth();
2184 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2185 yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
2186 xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
2187 }
2188
2189 auto interpolate = [](Value val0, Value val1, Value weight1,
2190 Value scale, int inputSize,
2191 ImplicitLocOpBuilder &b) -> Value {
2192 if (inputSize == 1)
2193 return arith::MulIOp::create(b, val0, scale);
2194 Value weight0 = arith::SubIOp::create(b, scale, weight1);
2195 Value mul0 = arith::MulIOp::create(b, val0, weight0);
2196 Value mul1 = arith::MulIOp::create(b, val1, weight1);
2197 return arith::AddIOp::create(b, mul0, mul1);
2198 };
2199
2200 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2201 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2202 Value result =
2203 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2204 linalg::YieldOp::create(b, result);
2205 }
2206 }
2207 }
2208
2209 rewriter.replaceOp(op, resize);
2210 return success();
2211 }
2212};
2213
2214// At the codegen level any identity operations should be removed. Any cases
2215// where identity is load-bearing (e.g. cross device computation) should be
2216// handled before lowering to codegen.
2217template <typename SrcOp>
2218class IdentityNConverter : public OpRewritePattern<SrcOp> {
2219public:
2220 using OpRewritePattern<SrcOp>::OpRewritePattern;
2221
2222 LogicalResult matchAndRewrite(SrcOp op,
2223 PatternRewriter &rewriter) const final {
2224 rewriter.replaceOp(op, op.getOperation()->getOperands());
2225 return success();
2226 }
2227};
2228
2229template <typename SrcOp>
2230class ReduceConverter : public OpRewritePattern<SrcOp> {
2231public:
2232 using OpRewritePattern<SrcOp>::OpRewritePattern;
2233
2234 LogicalResult matchAndRewrite(SrcOp reduceOp,
2235 PatternRewriter &rewriter) const final {
2236 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2237 }
2238};
2239
2240class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
2241public:
2242 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
2243
2244 LogicalResult matchAndRewrite(tosa::ReverseOp op,
2245 PatternRewriter &rewriter) const final {
2246 auto loc = op.getLoc();
2247 Value input = op.getInput1();
2248 auto inputTy = cast<ShapedType>(input.getType());
2249 auto resultTy = cast<ShapedType>(op.getType());
2250 auto axis = op.getAxis();
2251
2252 SmallVector<Value> dynDims;
2253 for (int i = 0; i < inputTy.getRank(); i++) {
2254 if (inputTy.isDynamicDim(i)) {
2255 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2256 }
2257 }
2258
2259 Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
2260
2261 // First create the empty output buffer.
2262 auto emptyTensor = tensor::EmptyOp::create(
2263 rewriter, loc, inputTy.getShape(),
2264 inputTy.getElementType(), ArrayRef<Value>({dynDims}))
2265 .getResult();
2266 SmallVector<AffineMap, 2> affineMaps = {
2267 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2268
2269 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2270 op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
2271 getNParallelLoopsAttrs(resultTy.getRank()),
2272 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2273 llvm::SmallVector<Value> indices;
2274 for (unsigned int i = 0; i < inputTy.getRank(); i++) {
2275 Value index =
2276 linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
2277 if (i == axis) {
2278 auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
2279 auto sizeMinusOne =
2280 arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
2281 index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
2282 index);
2283 }
2284
2285 indices.push_back(index);
2286 }
2287
2288 auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
2289 input, indices);
2290 linalg::YieldOp::create(nestedBuilder, op.getLoc(),
2291 extract.getResult());
2292 });
2293 return success();
2294 }
2295};
2296
2297 // This converter translates a tile operation to a reshape, broadcast, reshape.
2298 // The first reshape minimally expands each tiled dimension to include a
2299 // preceding size-1 dim. This dim is then broadcast to the appropriate
2300 // multiple.
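// Example sketch with assumed shapes: tiling tensor<2x3xf32> by multiples
// [2, 1] first materializes a broadcasting linalg.generic into
// tensor<2x2x1x3xf32> (each input dim preceded by its multiple), which the
// trailing tosa.reshape then flattens to the tensor<4x3xf32> result.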
2301struct TileConverter : public OpConversionPattern<tosa::TileOp> {
2302 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
2303
2304 LogicalResult
2305 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2306 ConversionPatternRewriter &rewriter) const override {
2307 auto loc = op.getLoc();
2308 auto input = op.getInput1();
2309 auto inputTy = cast<ShapedType>(input.getType());
2310 auto inputShape = inputTy.getShape();
2311 auto resultTy = cast<ShapedType>(op.getType());
2312 auto elementTy = inputTy.getElementType();
2313 int64_t rank = inputTy.getRank();
2314
2315 SmallVector<int64_t> multiples;
2316 if (failed(op.getConstantMultiples(multiples)))
2317 return failure();
2318
2319 // Broadcast the newly added dimensions to their appropriate multiple.
2320 SmallVector<int64_t, 2> genericShape;
2321 for (int i = 0; i < rank; i++) {
2322 int64_t dim = multiples[i];
2323 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
2324 genericShape.push_back(inputShape[i]);
2325 }
2326
2327 SmallVector<Value> dynDims;
2328 for (int i = 0; i < inputTy.getRank(); i++) {
2329 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
2330 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2331 }
2332 }
2333
2334 auto emptyTensor = tensor::EmptyOp::create(
2335 rewriter, op.getLoc(), genericShape, elementTy, dynDims);
2336
2337 // We need to map the input shape to the non-broadcasted dimensions.
2338 SmallVector<AffineExpr, 4> dimExprs;
2339 dimExprs.reserve(rank);
2340 for (unsigned i = 0; i < rank; ++i)
2341 dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
2342
2343 auto readAffineMap =
2344 AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
2345 rewriter.getContext());
2346
2347 SmallVector<AffineMap, 2> affineMaps = {
2348 readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
2349
2350 auto genericOp = linalg::GenericOp::create(
2351 rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
2352 ValueRange{emptyTensor}, affineMaps,
2353 getNParallelLoopsAttrs(genericShape.size()),
2354 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2355 linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
2356 });
2357
2358 auto shapeValue = getTosaConstShape(
2359 rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
2360 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2361 op, resultTy, genericOp.getResult(0), shapeValue);
2362 return success();
2363 }
2364};
2365
2366 // Tosa argmax lowering represents the ArgMax op as a linalg.generic op,
2367 // producing two output buffers.
2368 //
2369 // The first output buffer contains the index of the found maximum value. It is
2370 // initialized to 0 and is of the resulting integer type.
2371 //
2372 // The second output buffer contains the maximum value found. It is initialized
2373 // to the minimum representable value of the input element type. After being
2374 // populated by the generic op, this buffer is discarded as only the index is
2375 // requested.
2376 //
2377 // The generic op updates both the maximum value and index if the
2378 // current value exceeds the running max.
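// A small sketch with hypothetical data: arg_max over axis 1 of [[1, 5, 3]]
// carries two accumulators, the running index (initialized to 0) and the
// running max (initialized to the smallest representable value); after the
// row is visited they hold 1 and 5, and only the index tensor is returned.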
2379class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2380public:
2381 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2382
2383 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2384 PatternRewriter &rewriter) const final {
2385 auto loc = argmaxOp.getLoc();
2386 Value input = argmaxOp.getInput();
2387 auto inputTy = cast<ShapedType>(input.getType());
2388 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2389 auto inElementTy = inputTy.getElementType();
2390 auto outElementTy = resultTy.getElementType();
2391 int axis = argmaxOp.getAxis();
2392 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2393
2394 if (!isa<IntegerType>(outElementTy))
2395 return rewriter.notifyMatchFailure(
2396 argmaxOp,
2397 "tosa.arg_max to linalg.* requires integer-like result type");
2398
2399 SmallVector<Value> dynDims;
2400 for (int i = 0; i < inputTy.getRank(); i++) {
2401 if (inputTy.isDynamicDim(i) && i != axis) {
2402 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2403 }
2404 }
2405
2406 // First fill the output buffer for the index.
2407 auto emptyTensorIdx =
2408 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2409 outElementTy, dynDims)
2410 .getResult();
2411 auto fillValueIdx = arith::ConstantOp::create(
2412 rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
2413 auto filledTensorIdx =
2414 linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
2415 ValueRange{emptyTensorIdx})
2416 .result();
2417
2418 // Second fill the output buffer for the running max.
2419 auto emptyTensorMax =
2420 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
2421 dynDims)
2422 .getResult();
2423 auto fillValueMaxAttr =
2424 createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2425
2426 if (!fillValueMaxAttr)
2427 return rewriter.notifyMatchFailure(
2428 argmaxOp, "unsupported tosa.argmax element type");
2429
2430 auto fillValueMax =
2431 arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
2432 auto filledTensorMax =
2433 linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
2434 ValueRange{emptyTensorMax})
2435 .result();
2436
2437 // We need to reduce along the arg-max axis, with parallel operations along
2438 // the rest.
2439 SmallVector<utils::IteratorType, 4> iteratorTypes;
2440 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2441 iteratorTypes[axis] = utils::IteratorType::reduction;
2442
2443 SmallVector<AffineExpr, 2> srcExprs;
2444 SmallVector<AffineExpr, 2> dstExprs;
2445 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2446 srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2447 if (axis != i)
2448 dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2449 }
2450
2451 bool didEncounterError = false;
2452 auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2453 rewriter.getContext());
2454 auto linalgOp = linalg::GenericOp::create(
2455 rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2456 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2457 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2458 ValueRange blockArgs) {
2459 auto newValue = blockArgs[0];
2460 auto oldIndex = blockArgs[1];
2461 auto oldValue = blockArgs[2];
2462
2463 Value newIndex = arith::IndexCastOp::create(
2464 rewriter, nestedLoc, oldIndex.getType(),
2465 linalg::IndexOp::create(rewriter, loc, axis));
2466
2467 Value predicate;
2468 if (isa<FloatType>(inElementTy)) {
2469 if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
2470 // Only update the index & max value for non-NaN values. If all
2471 // values are NaNs, the initial index, which is 0, is returned.
2472 predicate = arith::CmpFOp::create(rewriter, nestedLoc,
2473 arith::CmpFPredicate::OGT,
2474 newValue, oldValue);
2475 } else {
2476 // Update max value if either of the following is true:
2477 // - new value is bigger
2478 // - cur max is not NaN and new value is NaN
2479 Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
2480 arith::CmpFPredicate::UGT,
2481 newValue, oldValue);
2482 Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
2483 arith::CmpFPredicate::ORD,
2484 oldValue, oldValue);
2485 predicate = arith::AndIOp::create(
2486 rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2487 }
2488 } else if (isa<IntegerType>(inElementTy)) {
2489 predicate = arith::CmpIOp::create(rewriter, nestedLoc,
2490 arith::CmpIPredicate::sgt,
2491 newValue, oldValue);
2492 } else {
2493 didEncounterError = true;
2494 return;
2495 }
2496
2497 auto resultMax = arith::SelectOp::create(
2498 rewriter, nestedLoc, predicate, newValue, oldValue);
2499 auto resultIndex = arith::SelectOp::create(
2500 rewriter, nestedLoc, predicate, newIndex, oldIndex);
2501 linalg::YieldOp::create(nestedBuilder, nestedLoc,
2502 ValueRange({resultIndex, resultMax}));
2503 });
2504
2505 if (didEncounterError)
2506 return rewriter.notifyMatchFailure(
2507 argmaxOp, "unsupported tosa.argmax element type");
2508
2509 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2510 return success();
2511 }
2512};
2513
2514class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2515public:
2516 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2517 LogicalResult
2518 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2519 ConversionPatternRewriter &rewriter) const final {
2520 auto input = adaptor.getOperands()[0];
2521 auto indices = adaptor.getOperands()[1];
2522
2523 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2524 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2525 if (!valuesTy || !resultTy)
2526 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2527
2528 auto dynamicDims = inferDynamicDimsForGather(
2529 rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2530
2531 auto resultElementTy = resultTy.getElementType();
2532
2533 auto loc = op.getLoc();
2534 auto emptyTensor =
2535 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2536 resultElementTy, dynamicDims)
2537 .getResult();
2538
2539 SmallVector<AffineMap, 2> affineMaps = {
2540 AffineMap::get(
2541 /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2542 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2543 rewriter.getContext()),
2544 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2545
2546 auto genericOp = linalg::GenericOp::create(
2547 rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2548 ValueRange{emptyTensor}, affineMaps,
2549 getNParallelLoopsAttrs(resultTy.getRank()),
2550 [&](OpBuilder &b, Location loc, ValueRange args) {
2551 auto indexValue = args[0];
2552 auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
2553 Value index1 = arith::IndexCastOp::create(
2554 rewriter, loc, rewriter.getIndexType(), indexValue);
2555 auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
2556 Value extract = tensor::ExtractOp::create(
2557 rewriter, loc, input, ValueRange{index0, index1, index2});
2558 linalg::YieldOp::create(rewriter, loc, extract);
2559 });
2560 rewriter.replaceOp(op, genericOp.getResult(0));
2561 return success();
2562 }
2563
2564 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2565 Location loc,
2566 Value values,
2567 Value indices) {
2568 llvm::SmallVector<Value> results;
2569
2570 auto addDynamicDimension = [&](Value source, int64_t dim) {
2571 auto sz = tensor::getMixedSize(builder, loc, source, dim);
2572 if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2573 results.push_back(dimValue);
2574 };
2575
2576 addDynamicDimension(values, 0);
2577 addDynamicDimension(indices, 1);
2578 addDynamicDimension(values, 2);
2579 return results;
2580 }
2581};
2582
2583 // Lowers the TableOp to a series of gathers and numeric operations. This
2584 // includes interpolation between the high/low values. For the I8 variant, this
2585 // simplifies to a single gather operation.
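// Worked sketch for the i16 path, using assumed values: an input of -32704
// becomes 64 after the +32768 bias, so index = 64 >> 7 = 0 and
// fraction = 64 & 0x7F = 64; the emitted code then yields
// (table[0] << 7) + (table[1] - table[0]) * 64, interpolating between the
// first two table entries.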
2586class TableConverter : public OpRewritePattern<tosa::TableOp> {
2587public:
2588 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2589
2590 LogicalResult matchAndRewrite(tosa::TableOp op,
2591 PatternRewriter &rewriter) const final {
2592 auto loc = op.getLoc();
2593 Value input = op.getInput1();
2594 Value table = op.getTable();
2595 auto inputTy = cast<ShapedType>(input.getType());
2596 auto tableTy = cast<ShapedType>(table.getType());
2597 auto resultTy = cast<ShapedType>(op.getType());
2598
2599 auto inputElementTy = inputTy.getElementType();
2600 auto tableElementTy = tableTy.getElementType();
2601 auto resultElementTy = resultTy.getElementType();
2602
2603 SmallVector<Value> dynDims;
2604 for (int i = 0; i < resultTy.getRank(); ++i) {
2605 if (inputTy.isDynamicDim(i)) {
2606 dynDims.push_back(
2607 tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
2608 }
2609 }
2610
2611 auto emptyTensor =
2612 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2613 resultElementTy, dynDims)
2614 .getResult();
2615
2616 SmallVector<AffineMap, 2> affineMaps = {
2617 rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2618 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2619
2620 auto genericOp = linalg::GenericOp::create(
2621 rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
2622 affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
2623 rewriter.replaceOp(op, genericOp.getResult(0));
2624
2625 {
2626 OpBuilder::InsertionGuard regionGuard(rewriter);
2627 Block *block = rewriter.createBlock(
2628 &genericOp.getRegion(), genericOp.getRegion().end(),
2629 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2630
2631 auto inputValue = block->getArgument(0);
2632 rewriter.setInsertionPointToStart(block);
2633 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2634 resultElementTy.isInteger(8)) {
2635 Value index = arith::IndexCastOp::create(
2636 rewriter, loc, rewriter.getIndexType(), inputValue);
2637 Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
2638 index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
2639 index, offset);
2640 Value extract =
2641 tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2642 linalg::YieldOp::create(rewriter, loc, extract);
2643 return success();
2644 }
2645
2646 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2647 resultElementTy.isInteger(32)) {
2648 Value extend = arith::ExtSIOp::create(
2649 rewriter, loc, rewriter.getI32Type(), inputValue);
2650
2651 auto offset = arith::ConstantOp::create(
2652 rewriter, loc, rewriter.getI32IntegerAttr(32768));
2653 auto seven = arith::ConstantOp::create(rewriter, loc,
2654 rewriter.getI32IntegerAttr(7));
2655 auto one = arith::ConstantOp::create(rewriter, loc,
2656 rewriter.getI32IntegerAttr(1));
2657 auto b1111111 = arith::ConstantOp::create(
2658 rewriter, loc, rewriter.getI32IntegerAttr(127));
2659
2660 // Compute the index and fractional part from the input value:
2661 // value = value + 32768
2662 // index = value >> 7;
2663 // fraction = value & 0b1111111
2664 auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
2665 Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
2666 Value fraction =
2667 arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
2668
2669 // Extract the base and next values from the table.
2670 // base = (int32_t) table[index];
2671 // next = (int32_t) table[index + 1];
2672 Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
2673
2674 index = arith::IndexCastOp::create(rewriter, loc,
2675 rewriter.getIndexType(), index);
2676 indexPlusOne = arith::IndexCastOp::create(
2677 rewriter, loc, rewriter.getIndexType(), indexPlusOne);
2678
2679 Value base =
2680 tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2681 Value next = tensor::ExtractOp::create(rewriter, loc, table,
2682 ValueRange{indexPlusOne});
2683
2684 base =
2685 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
2686 next =
2687 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
2688
2689 // Use the fractional part to interpolate between the input values:
2690 // result = (base << 7) + (next - base) * fraction
2691 Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
2692 Value diff = arith::SubIOp::create(rewriter, loc, next, base);
2693 Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
2694 Value result =
2695 arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
2696
2697 linalg::YieldOp::create(rewriter, loc, result);
2698
2699 return success();
2700 }
2701 }
2702
2703 return rewriter.notifyMatchFailure(
2704 op, "unable to create body for tosa.table op");
2705 }
2706};
2707
2708struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2709 using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2710
2711 static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2712
2713 static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2714 OpFoldResult ofr) {
2715 auto one = arith::ConstantIndexOp::create(builder, loc, 1);
2716 auto two = arith::ConstantIndexOp::create(builder, loc, 2);
2717
2718 auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2719 auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2720 auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2721 return getAsOpFoldResult(plusOne);
2722 }
2723
2724 static RankedTensorType
2725 computeOutputShape(OpBuilder &builder, Location loc, Value input,
2726 llvm::SmallVectorImpl<Value> &dynamicSizes) {
2727 // Get [N, H, W]
2728 auto dims = tensor::getMixedSizes(builder, loc, input);
2729
2730 // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2731 // output tensors.
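// For instance, a real input of width W = 8 yields 8 / 2 + 1 = 5 output
// columns, matching the one-sided spectrum produced by tosa.rfft2d.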
2732 dims[2] = halfPlusOne(builder, loc, dims[2]);
2733
2734 llvm::SmallVector<int64_t, 3> staticSizes;
2735 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2736
2737 auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2738 return RankedTensorType::get(staticSizes, elementType);
2739 }
2740
2741 static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2742 RankedTensorType type,
2743 llvm::ArrayRef<Value> dynamicSizes) {
2744 auto emptyTensor =
2745 tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
2746 auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2747 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
2748 auto filledTensor =
2749 linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
2750 ValueRange{emptyTensor})
2751 .result();
2752 return filledTensor;
2753 }
2754
2755 static Value castIndexToFloat(OpBuilder &builder, Location loc,
2756 FloatType type, Value value) {
2757 auto integerVal = arith::IndexCastUIOp::create(
2758 builder, loc,
2759 type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2760 : builder.getI32Type(),
2761 value);
2762
2763 return arith::UIToFPOp::create(builder, loc, type, integerVal);
2764 }
2765
2766 static Value createLinalgIndex(OpBuilder &builder, Location loc,
2767 FloatType type, int64_t index) {
2768 auto indexVal = linalg::IndexOp::create(builder, loc, index);
2769 return castIndexToFloat(builder, loc, type, indexVal);
2770 }
2771
2772 template <typename... Args>
2773 static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2774 Args... args) {
2775 return {builder.getAffineDimExpr(args)...};
2776 }
2777
2778 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2779 PatternRewriter &rewriter) const override {
2780 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2781 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2782 return rewriter.notifyMatchFailure(rfft2d,
2783 "only supports ranked tensors");
2784 }
2785
2786 auto loc = rfft2d.getLoc();
2787 auto input = rfft2d.getInputReal();
2788 auto elementType =
2789 dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2790 if (!elementType)
2791 return rewriter.notifyMatchFailure(rfft2d,
2792 "only supports float element types");
2793
2794 // Compute the output type and set of dynamic sizes
2795 llvm::SmallVector<Value> dynamicSizes;
2796 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2797
2798 // Iterator types for the linalg.generic implementation
2799 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2800 utils::IteratorType::parallel, utils::IteratorType::parallel,
2801 utils::IteratorType::parallel, utils::IteratorType::reduction,
2802 utils::IteratorType::reduction};
2803
2804 // Inputs/outputs to the linalg.generic implementation
2805 llvm::SmallVector<Value> genericOpInputs = {input};
2806 llvm::SmallVector<Value> genericOpOutputs = {
2807 createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2808 createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2809
2810 // Indexing maps for input and output tensors
2811 auto indexingMaps = AffineMap::inferFromExprList(
2812 llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2813 affineDimsExpr(rewriter, 0, 1, 2),
2814 affineDimsExpr(rewriter, 0, 1, 2)},
2815 rewriter.getContext());
2816
2817 // Width and height dimensions of the original input.
2818 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2819 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2820
2821 // Constants and dimension sizes
2822 auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2823 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2824 auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2825 auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2826
2827 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2828 Value valReal = args[0];
2829 Value sumReal = args[1];
2830 Value sumImag = args[2];
2831
2832 // Indices for angle computation
2833 Value oy = linalg::IndexOp::create(builder, loc, 1);
2834 Value ox = linalg::IndexOp::create(builder, loc, 2);
2835 Value iy = linalg::IndexOp::create(builder, loc, 3);
2836 Value ix = linalg::IndexOp::create(builder, loc, 4);
2837
2838 // Calculating angle without integer parts of components as sin/cos are
2839 // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2840 // / W);
2841 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2842 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2843
2844 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2845 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2846
2847 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2848 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2849
2850 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2851 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2852 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2853 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2854
2855 // realComponent = valReal * cos(angle)
2856 // imagComponent = valReal * sin(angle)
2857 auto cosAngle = math::CosOp::create(builder, loc, angle);
2858 auto sinAngle = math::SinOp::create(builder, loc, angle);
2859 auto realComponent =
2860 arith::MulFOp::create(builder, loc, valReal, cosAngle);
2861 auto imagComponent =
2862 arith::MulFOp::create(builder, loc, valReal, sinAngle);
2863
2864 // outReal = sumReal + realComponent
2865 // outImag = sumImag - imagComponent
2866 auto outReal =
2867 arith::AddFOp::create(builder, loc, sumReal, realComponent);
2868 auto outImag =
2869 arith::SubFOp::create(builder, loc, sumImag, imagComponent);
2870
2871 linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
2872 };
2873
2874 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2875 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2876 indexingMaps, iteratorTypes, buildBody);
2877
2878 return success();
2879 }
2880};
2881
2882 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2883 using OpRewritePattern::OpRewritePattern;
2884
2885 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2886 PatternRewriter &rewriter) const override {
2887 if (!llvm::all_of(fft2d->getOperandTypes(),
2888 RFFT2dConverter::isRankedTensor) ||
2889 !llvm::all_of(fft2d->getResultTypes(),
2890 RFFT2dConverter::isRankedTensor)) {
2891 return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2892 }
2893
2894 Location loc = fft2d.getLoc();
2895 Value input_real = fft2d.getInputReal();
2896 Value input_imag = fft2d.getInputImag();
2897 BoolAttr inverse = fft2d.getInverseAttr();
2898
2899 auto real_el_ty = cast<FloatType>(
2900 cast<ShapedType>(input_real.getType()).getElementType());
2901 [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2902 cast<ShapedType>(input_imag.getType()).getElementType());
2903
2904 assert(real_el_ty == imag_el_ty);
2905
2906 // Compute the output type and set of dynamic sizes
2907 SmallVector<Value> dynamicSizes;
2908
2909 // Get [N, H, W]
2910 auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2911
2912 SmallVector<int64_t, 3> staticSizes;
2913 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
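    // For example, for a tensor<?x4x8xf32> input, dims is {%dim0, 4, 8};
    // dispatching splits this into staticSizes = {ShapedType::kDynamic, 4, 8}
    // and dynamicSizes = {%dim0}, which is the form expected by
    // RankedTensorType::get and the zero-tensor helpers below.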
2914
2915 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2916
2917 // Iterator types for the linalg.generic implementation
2918 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2919 utils::IteratorType::parallel, utils::IteratorType::parallel,
2920 utils::IteratorType::parallel, utils::IteratorType::reduction,
2921 utils::IteratorType::reduction};
2922
2923 // Inputs/outputs to the linalg.generic implementation
2924 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2925 SmallVector<Value> genericOpOutputs = {
2926 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2927 dynamicSizes),
2928 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2929 dynamicSizes)};
2930
2931 // Indexing maps for input and output tensors
2932 auto indexingMaps = AffineMap::inferFromExprList(
2933 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2934 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2935 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2936 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2937 rewriter.getContext());
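    // Assuming affineDimsExpr(rewriter, a, b, c) builds (d<a>, d<b>, d<c>), the
    // inferred maps over the loops (N, oy, ox, iy, ix) are roughly:
    //   inputs : (d0, d1, d2, d3, d4) -> (d0, d3, d4)
    //   outputs: (d0, d1, d2, d3, d4) -> (d0, d1, d2)
    // i.e. the generic op sums over (iy, ix) for every output point (N, oy, ox).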
2938
2939 // Height and width dimensions of the original input.
2940 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2941 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2942
2943 // Constants and dimension sizes
2944 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2945 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2946 Value constH =
2947 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2948 Value constW =
2949 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2950
2951 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2952 Value valReal = args[0];
2953 Value valImag = args[1];
2954 Value sumReal = args[2];
2955 Value sumImag = args[3];
2956
2957 // Indices for angle computation
2958 Value oy = linalg::IndexOp::create(builder, loc, 1);
2959 Value ox = linalg::IndexOp::create(builder, loc, 2);
2960 Value iy = linalg::IndexOp::create(builder, loc, 3);
2961 Value ix = linalg::IndexOp::create(builder, loc, 4);
2962
2963 // float_t angle = sign_val * 2 * pi() *
2964 //     ( ((iy * oy) % H) / H + ((ix * ox) % W) / W );
2965 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2966 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2967
2968 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2969 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2970
2971 auto iyRemFloat =
2972 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2973 auto ixRemFloat =
2974 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2975
2976 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2977 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2978
2979 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2980 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2981
2982 if (inverse.getValue()) {
2983 angle = arith::MulFOp::create(
2984 builder, loc, angle,
2985 arith::ConstantOp::create(rewriter, loc,
2986 rewriter.getFloatAttr(real_el_ty, -1.0)));
2987 }
2988
2989 // realComponent = val_real * cos(a) + val_imag * sin(a);
2990 // imagComponent = -val_real * sin(a) + val_imag * cos(a);
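     // These follow from expanding the complex product with one twiddle factor
     // of the forward DFT:
     //   (val_real + j*val_imag) * (cos(a) - j*sin(a))
     //     = (val_real*cos(a) + val_imag*sin(a))
     //       + j * (val_imag*cos(a) - val_real*sin(a)).
     // The inverse transform is handled by negating the angle above, not by
     // changing these formulas.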
2991 auto cosAngle = math::CosOp::create(builder, loc, angle);
2992 auto sinAngle = math::SinOp::create(builder, loc, angle);
2993
2994 auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
2995 auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
2996 auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
2997
2998 auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
2999 auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
3000
3001 auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
3002
3003 // outReal = sumReal + realComponent
3004 // outImag = sumImag + imagComponent
3005 auto outReal =
3006 arith::AddFOp::create(builder, loc, sumReal, realComponent);
3007 auto outImag =
3008 arith::AddFOp::create(builder, loc, sumImag, imagComponent);
3009
3010 linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
3011 };
3012
3013 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
3014 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
3015 indexingMaps, iteratorTypes, buildBody);
3016
3017 return success();
3018 }
3019};
3020
3021} // namespace
3022
3023void mlir::tosa::populateTosaToLinalgConversionPatterns(
3024 const TypeConverter &converter, RewritePatternSet *patterns) {
3025
3026 // We have multiple resize converters to handle degenerate cases; patterns with higher benefit are tried first, so the specialized lowerings win over GenericResizeConverter.
3027 patterns->add<GenericResizeConverter>(patterns->getContext(),
3028 /*benefit=*/100);
3029 patterns->add<ResizeUnaryConverter>(patterns->getContext(),
3030 /*benefit=*/200);
3031 patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
3032 /*benefit=*/300);
3033
3034 patterns->add<
3035 // clang-format off
3036 PointwiseConverter<tosa::AddOp>,
3037 PointwiseConverter<tosa::SubOp>,
3038 PointwiseConverter<tosa::MulOp>,
3039 PointwiseConverter<tosa::IntDivOp>,
3040 PointwiseConverter<tosa::NegateOp>,
3041 PointwiseConverter<tosa::PowOp>,
3042 PointwiseConverter<tosa::ReciprocalOp>,
3043 PointwiseConverter<tosa::RsqrtOp>,
3044 PointwiseConverter<tosa::LogOp>,
3045 PointwiseConverter<tosa::ExpOp>,
3046 PointwiseConverter<tosa::AbsOp>,
3047 PointwiseConverter<tosa::SinOp>,
3048 PointwiseConverter<tosa::CosOp>,
3049 PointwiseConverter<tosa::TanhOp>,
3050 PointwiseConverter<tosa::ErfOp>,
3051 PointwiseConverter<tosa::BitwiseAndOp>,
3052 PointwiseConverter<tosa::BitwiseOrOp>,
3053 PointwiseConverter<tosa::BitwiseNotOp>,
3054 PointwiseConverter<tosa::BitwiseXorOp>,
3055 PointwiseConverter<tosa::LogicalAndOp>,
3056 PointwiseConverter<tosa::LogicalNotOp>,
3057 PointwiseConverter<tosa::LogicalOrOp>,
3058 PointwiseConverter<tosa::LogicalXorOp>,
3059 PointwiseConverter<tosa::CastOp>,
3060 PointwiseConverter<tosa::LogicalLeftShiftOp>,
3061 PointwiseConverter<tosa::LogicalRightShiftOp>,
3062 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
3063 PointwiseConverter<tosa::ClzOp>,
3064 PointwiseConverter<tosa::SelectOp>,
3065 PointwiseConverter<tosa::GreaterOp>,
3066 PointwiseConverter<tosa::GreaterEqualOp>,
3067 PointwiseConverter<tosa::EqualOp>,
3068 PointwiseConverter<tosa::MaximumOp>,
3069 PointwiseConverter<tosa::MinimumOp>,
3070 PointwiseConverter<tosa::CeilOp>,
3071 PointwiseConverter<tosa::FloorOp>,
3072 PointwiseConverter<tosa::ClampOp>,
3073 PointwiseConverter<tosa::SigmoidOp>
3074 >(converter, patterns->getContext());
3075
3076 patterns->add<
3077 IdentityNConverter<tosa::IdentityOp>,
3078 ReduceConverter<tosa::ReduceAllOp>,
3079 ReduceConverter<tosa::ReduceAnyOp>,
3080 ReduceConverter<tosa::ReduceMinOp>,
3081 ReduceConverter<tosa::ReduceMaxOp>,
3082 ReduceConverter<tosa::ReduceSumOp>,
3083 ReduceConverter<tosa::ReduceProductOp>,
3084 ArgMaxConverter,
3085 GatherConverter,
3086 RescaleConverter,
3087 ReverseConverter,
3088 RFFT2dConverter,
3089 FFT2dConverter,
3090 TableConverter,
3091 TileConverter>(patterns->getContext());
3092 // clang-format on
3093}
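
// A minimal sketch of how these patterns are typically driven from a dialect
// conversion pass. The pass name and type-converter setup below are
// illustrative only (`MyTosaToLinalgPass` is hypothetical); the in-tree driver
// lives in TosaToLinalgPass.cpp and configures additional legality rules:
//
//   void MyTosaToLinalgPass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     RewritePatternSet patterns(ctx);
//     TypeConverter converter;
//     converter.addConversion([](Type type) { return type; });
//     mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
//
//     ConversionTarget target(*ctx);
//     target.addIllegalDialect<tosa::TosaDialect>();
//     target.addLegalDialect<linalg::LinalgDialect, arith::ArithDialect,
//                            math::MathDialect, index::IndexDialect,
//                            tensor::TensorDialect>();
//     if (failed(applyFullConversion(getOperation(), target,
//                                    std::move(patterns))))
//       signalPassFailure();
//   }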